diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index c9448f329..95b946ae4 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -48,7 +48,7 @@ workflow: when: never - if: $CI_MERGE_REQUEST_TITLE !~ /Draft:/ variables: - ROCPRIM_TEST_RUNS: 3 + ROCPRIM_TEST_RUNS: 1 - if: $CI_MERGE_REQUEST_TITLE =~ /Draft:/ variables: ROCPRIM_TEST_RUNS: 1 @@ -124,7 +124,7 @@ copyright-date: -D AMDGPU_TARGETS=$GPU_TARGETS -D CMAKE_C_COMPILER_LAUNCHER=phc_sccache_c -D CMAKE_CXX_COMPILER_LAUNCHER=phc_sccache_cxx - -D CMAKE_CXX_STANDARD=14 + -D CMAKE_CXX_STANDARD=17 -S $CI_PROJECT_DIR -B $BUILD_DIR - cmake @@ -174,6 +174,7 @@ build:cmake-minimum-apt: - .rules:build variables: EXTRA_CMAKE_CXX_FLAGS: "" + BUILD_TOOL_ARGS: "" script: - mkdir -p $BUILD_DIR - cd $BUILD_DIR @@ -196,7 +197,7 @@ build:cmake-minimum-apt: -D CMAKE_CXX_STANDARD="$BUILD_VERSION" -S $CI_PROJECT_DIR -B $BUILD_DIR - - cmake --build $BUILD_DIR + - cmake --build $BUILD_DIR -- ${BUILD_TOOL_ARGS} artifacts: paths: - $BUILD_DIR/.ninja_log @@ -213,6 +214,31 @@ build:cmake-minimum-apt: - $BUILD_DIR/test/test_* expire_in: 1 day +build:spirv: + stage: build + needs: [] + extends: + - .cmake-minimum + - .build:common + variables: + # For unknown reasons spir-v builds ignore 'clang diagnostic' pragmas that + # we use to ignore internal deprecations. + EXTRA_CMAKE_CXX_FLAGS: "-Wno-deprecated-declarations -mf16c -DROCPRIM_EXPERIMENTAL_SPIRV" + # Since not all targets are expected to build, do not stop building other + # targets when any target fails. + BUILD_TOOL_ARGS: "-k 0" + GPU_TARGETS: "amdgcnspirv" + image: "registry.streamhpc.internal/unstable-rocm:main" + allow_failure: true + parallel: + # Debug builds disabled due to excessive build times for debug test builds + matrix: + - BUILD_TYPE: Release + BUILD_TARGET: [BENCHMARK, TEST] + BUILD_VERSION: 17 + artifacts: + when: always + build:cmake-latest: stage: build needs: [] @@ -224,7 +250,7 @@ build:cmake-latest: matrix: - BUILD_TYPE: Release BUILD_TARGET: [BENCHMARK, TEST] - BUILD_VERSION: [14, 17] + BUILD_VERSION: 17 build:cmake-minimum: needs: [] @@ -235,7 +261,7 @@ build:cmake-minimum: matrix: - BUILD_TYPE: [Debug, Release] BUILD_TARGET: [BENCHMARK, TEST] - BUILD_VERSION: 14 + BUILD_VERSION: 17 build:package: stage: build @@ -252,7 +278,7 @@ build:package: -G Ninja -D CMAKE_CXX_COMPILER="$AMDCLANG" -D CMAKE_BUILD_TYPE=Release - -D CMAKE_CXX_STANDARD=14 + -D CMAKE_CXX_STANDARD=17 -B $PACKAGE_DIR -S $CI_PROJECT_DIR - cd $PACKAGE_DIR @@ -285,7 +311,7 @@ build:windows: -D CMAKE_CXX_COMPILER:PATH="${env:HIP_PATH}\bin\clang++.exe" -D CMAKE_PREFIX_PATH:PATH="${env:HIP_PATH}" -D CMAKE_BUILD_TYPE="$BUILD_TYPE" - -D CMAKE_CXX_STANDARD=14 + -D CMAKE_CXX_STANDARD=17 - cmake --build "$CI_PROJECT_DIR/build" artifacts: paths: @@ -332,7 +358,7 @@ autotune:build: -D AMDGPU_TARGETS=$GPU_TARGETS -D CMAKE_C_COMPILER_LAUNCHER=phc_sccache_c -D CMAKE_CXX_COMPILER_LAUNCHER=phc_sccache_cxx - -D CMAKE_CXX_STANDARD=14 + -D CMAKE_CXX_STANDARD=17 - cmake --build . --target $BENCHMARK_TARGETS - 'rm -rf $BUILD_DIR/benchmark/benchmark*.parallel' # The autotune benchmarks get very large, above GitLabs upload limit. Fortunately they compress well. @@ -359,7 +385,7 @@ autotune:build: matrix: - BUILD_TYPE: Release BUILD_TARGET: TEST - BUILD_VERSION: 14 + BUILD_VERSION: 17 script: - cd $BUILD_DIR - cmake @@ -398,6 +424,62 @@ test:all-gpus: - .test:common - .rules:test +.test:common-spirv: + stage: test + tags: + - rocm + - $GPU + extends: + - .cmake-minimum + allow_failure: true + timeout: 3h + needs: + - job: build:spirv + parallel: + matrix: + - BUILD_TYPE: Release + BUILD_TARGET: TEST + BUILD_VERSION: 17 + image: "registry.streamhpc.internal/unstable-rocm:main" + script: + - cd $BUILD_DIR + - cmake + -D CMAKE_PREFIX_PATH=/opt/rocm + -P $CI_PROJECT_DIR/cmake/GenerateResourceSpec.cmake + - cat ./resources.json + # Parallel execution (with other AMDGPU processes) can oversubscribe the SDMA queue. + # This causes the hipMemcpy to fail, which is not reported as an error by HIP. + # As a temporary workaround, disable the SDMA for test stability. + - HSA_ENABLE_SDMA=0 ctest + --output-on-failure + --repeat-until-fail 2 + --resource-spec-file ./resources.json + --parallel $PARALLEL_JOBS + --exclude-regex rocprim.device_partition + +test:any-gpu-spirv: + variables: + GPU: "" + PARALLEL_JOBS: 1 + extends: + - .test:common-spirv + rules: + - if: $CI_MERGE_REQUEST_TITLE =~ /Draft:/ && $CI_MERGE_REQUEST_LABELS !~ /Arch::/ + +test:label-arch-spirv: + extends: + - .gpus:rocm + - .test:common-spirv + - .rules:arch-labels + +test:all-gpus-spirv: + variables: + SHOULD_BE_UNDRAFTED: "true" + extends: + - .gpus:rocm + - .test:common-spirv + - .rules:test + .test-windows-base: stage: test extends: @@ -437,7 +519,7 @@ test-windows-release: -D CMAKE_CXX_COMPILER="$AMDCLANG" -D CMAKE_BUILD_TYPE=Release -D AMDGPU_TARGETS=$GPU_TARGETS - -D CMAKE_CXX_STANDARD=14 + -D CMAKE_CXX_STANDARD=17 -S "$CI_PROJECT_DIR/test/extra" -B "$CI_PROJECT_DIR/package_test" - cmake --build "$CI_PROJECT_DIR/package_test" @@ -459,7 +541,7 @@ test:install: -G Ninja -D CMAKE_CXX_COMPILER="$AMDCLANG" -D CMAKE_BUILD_TYPE=Release - -D CMAKE_CXX_STANDARD=14 + -D CMAKE_CXX_STANDARD=17 -B build -S $CI_PROJECT_DIR # Preserve $PATH when sudoing @@ -507,7 +589,7 @@ benchmark: matrix: - BUILD_TYPE: Release BUILD_TARGET: BENCHMARK - BUILD_VERSION: 14 + BUILD_VERSION: 17 extends: - .cmake-minimum - .gpus:rocm diff --git a/CHANGELOG.md b/CHANGELOG.md index f405848e9..bffcceb3b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,19 @@ Full documentation for rocPRIM is available at [https://rocm.docs.amd.com/projects/rocPRIM/en/latest/](https://rocm.docs.amd.com/projects/rocPRIM/en/latest/). -## rocPRIM 3.6.0 for ROCm 7.0 +## rocPRIM 4.0.0 for ROCm 7.0 + +### Added + +* Added `rocprim::accumulator_t` to ensure parity with CCCL. +* Added test for `rocprim::accumulator_t` +* Added `rocprim::invoke_result_r` to ensure parity with CCCL. +* Added function `is_build_in` into `rocprim::traits::get`. +* Added virtual shared memory as a fallback option in `rocprim::device_merge` when it exceeds shared memory capacity, similar to `rocprim::device_select`, `rocprim::device_partition`, and `rocprim::device_merge_sort`, which already include this feature. +* Added initial value support to device level inclusive scans. +* Added new optimization to the backend for `device_transform` when the input and output are pointers. +* Added `LoadType` to `transform_config`, which is used for the `device_transform` when the input and output are pointers. +* Added `rocprim:device_transform` for n-ary transform operations API with as input `n` number of iterators inside a `rocprim::tuple`. ### Changed @@ -24,6 +36,62 @@ This is a complete list of affected functions and how their default accumulator * past default: `class AccType = detail::input_type_t>` * new default: `class AccType = rocprim::invoke_result_binary_op_t, BinaryFunction>` +* Changed the parameters `long_radix_bits` and `LongRadixBits` from `segmented_radix_sort` to `radix_bits` and `RadixBits` respectively. +* Marked the initialisation constructor of `rocprim::reverse_iterator` `explicit`, use `rocprim::make_reverse_iterator`. +* Merged `radix_key_codec` into type_traits system. +* Renamed `type_traits_interface.hpp` to `type_traits.hpp`, rename the original `type_traits.hpp` to `type_traits_functions.hpp`. +* Changed the default accumulator type for various device-level scan algorithms: + * `rocprim::inclusive_scan` + * Previous default: `class AccType = typename std::iterator_traits::value_type>` + * Current default: `class AccType = rocprim::accumulator_t::value_type>` + * `rocprim::deterministic_inclusive_scan` + * Previous default: `class AccType = typename std::iterator_traits::value_type>` + * Current default: `class AccType = rocprim::accumulator_t::value_type>` + * `rocprim::exclusive_scan` + * Previous default: `class AccType = detail::input_type_t>` + * Current default: `class AccType = rocprim::accumulator_t>` + * `rocprim::deterministic_exclusive_scan` + * Previous default: `class AccType = detail::input_type_t>` + * Current default: `class AccType = rocprim::accumulator_t>` + +### Deprecations + +* `rocprim::invoke_result_binary_op` and `rocprim::invoke_result_binary_op_t` are deprecated. Use `rocprim::accumulator_t` now. + +### Removed + +* Removed `rocprim::detail::float_bit_mask` and relative tests, use `rocprim::traits::float_bit_mask` instead. +* Removed `rocprim::traits::is_fundamental`, please use `rocprim::traits::get::is_fundamental()` directly. +* Removed the deprecated parameters `short_radix_bits` and `ShortRadixBits` from the `segmented_radix_sort` config. They were unused, it is only an API change. +* Removed the deprecated `operator<<` from the iterators. +* Removed the deprecated `TwiddleIn` and `TwiddleOut`. Use `radix_key_codec` instead. +* Removed the deprecated flags API of `block_adjacent_difference`. Use `subtract_left()` or `block_discontinuity::flag_heads()` instead. +* Removed the deprecated `to_exclusive` functions in the warp scans. +* Removed the `rocprim::load_cs` from the `cache_load_modifier` enum. Use `rocprim::load_nontemporal` instead. +* Removed the `rocprim::store_cs` from the `cache_store_modifier` enum. Use `rocprim::store_nontemporal` instead. +* Removed the deprecated header file `rocprim/detail/match_result_type.hpp`. Include `rocprim/type_traits.hpp` instead. + * This header included `rocprim::detail::invoke_result`. Use `rocprim::invoke_result` instead. + * This header included `rocprim::detail::invoke_result_binary_op`. Use `rocprim::invoke_result_binary_op` instead. + * This header included `rocprim::detail::match_result_type`. Use `rocprim::invoke_result_binary_op_t` instead. +* Removed the deprecated `rocprim::detail::radix_key_codec` function. Use `rocprim::radix_key_codec` instead. +* Removed `rocprim/detail/radix_sort.hpp`, functionality can now be found in `rocprim/thread/radix_key_codec.hpp`. +* Removed C++14 support, only C++17 is supported. +* Due to the removal of `__AMDGCN_WAVEFRONT_SIZE` in the compiler, the following deprecated warp size-related symbols have been removed: + * `rocprim::device_warp_size()` + * For compile-time constants, this is replaced with `rocprim::arch::wavefront::min_size()` and `rocprim::arch::wavefront::max_size()`. Use this when allocating global or shared memory. + * For run-time constants, this is replaced with `rocprim::arch::wavefront::size().` + * `rocprim::warp_size()` + * Use `rocprim::host_warp_size()`, `rocprim::arch::wavefront::min_size()` or `rocprim::arch::wavefront::max_size()` instead. + * `ROCPRIM_WAVEFRONT_SIZE` + * Use `rocprim::arch::wavefront::min_size()` or `rocprim::arch::wavefront::max_size()` instead. + * `__AMDGCN_WAVEFRONT_SIZE` + * This was a fallback define for the compiler's removed symbol, having the same name. + +### Resolved issues + +* Fixed an issue where `device_batch_memcpy` reported benchmarking throughput being 2x lower than it was in reality. +* Fixed an issue where `device_segmented_reduce` reported autotuning throughput being 5x lower than it was in reality. + ## rocPRIM 3.5.0 for ROCm 6.5.0 ### Removed @@ -40,6 +108,7 @@ This is a complete list of affected functions and how their default accumulator * Added the `rocprim::merge_inplace` function for merging in-place. * Added initial value support for warp- and block-level inclusive scan. * Added support for building tests with device-side random data generation, making them finish faster. This requires rocRAND, and is enabled with the `WITH_ROCRAND=ON` build flag. +* Added tests and documentation to `lookback_scan_state`. It is still in the `detail` namespace. ### Changed @@ -599,3 +668,5 @@ The following is the complete list of affected functions and how their default a * Switched to HIP-Clang as the default compiler * CMake searches for rocPRIM locally first; if t's not found, CMake downloads it from GitHub + + diff --git a/CMakeLists.txt b/CMakeLists.txt index 7a7e48373..b58e7e9be 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -41,11 +41,8 @@ set(CMAKE_HIP_STANDARD 14) set(CMAKE_HIP_STANDARD_REQUIRED ON) set(CMAKE_HIP_EXTENSIONS OFF) -# Set CXX standard -if (CMAKE_CXX_STANDARD EQUAL 14) - message(WARNING "C++14 will be deprecated in the next major release") -elseif(NOT CMAKE_CXX_STANDARD EQUAL 17) - message(FATAL_ERROR "Only C++14 and C++17 are supported") +if(NOT CMAKE_CXX_STANDARD EQUAL 17) + message(FATAL_ERROR "Only C++17 is supported") endif() if (CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR) @@ -183,7 +180,7 @@ if(BUILD_CODE_COVERAGE) endif() # Setup VERSION -set(VERSION_STRING "3.5.0") +set(VERSION_STRING "4.0.0") rocm_setup_version(VERSION ${VERSION_STRING}) math(EXPR rocprim_VERSION_NUMBER "${rocprim_VERSION_MAJOR} * 100000 + ${rocprim_VERSION_MINOR} * 100 + ${rocprim_VERSION_PATCH}") diff --git a/benchmark/CMakeLists.txt b/benchmark/CMakeLists.txt index 472b7df54..a45a6dd1f 100644 --- a/benchmark/CMakeLists.txt +++ b/benchmark/CMakeLists.txt @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2017-2025 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -174,6 +174,7 @@ add_rocprim_benchmark(benchmark_device_segmented_radix_sort_keys.cpp) add_rocprim_benchmark(benchmark_device_segmented_radix_sort_pairs.cpp) add_rocprim_benchmark(benchmark_device_segmented_reduce.cpp) add_rocprim_benchmark(benchmark_device_transform.cpp) +add_rocprim_benchmark(benchmark_device_transform_pointer.cpp) add_rocprim_benchmark(benchmark_predicate_iterator.cpp) add_rocprim_benchmark(benchmark_warp_exchange.cpp) add_rocprim_benchmark(benchmark_warp_reduce.cpp) diff --git a/benchmark/ConfigAutotuneSettings.cmake b/benchmark/ConfigAutotuneSettings.cmake index e3aa88ba3..acd4f1292 100644 --- a/benchmark/ConfigAutotuneSettings.cmake +++ b/benchmark/ConfigAutotuneSettings.cmake @@ -85,18 +85,23 @@ ${TUNING_TYPES};${LIMITED_TUNING_TYPES};using_warp_scan reduce_then_scan" PARENT set(list_across "\ binary_search upper_bound lower_bound;${TUNING_TYPES};${LIMITED_TUNING_TYPES};64 128 256;1 2 4 8 16" PARENT_SCOPE) set(output_pattern_suffix "@SubAlgorithm@_@ValueType@_@OutputType@_@BlockSize@_@ItemsPerThread@" PARENT_SCOPE) + elseif(file STREQUAL "benchmark_device_search_n") + set(list_across_names "InputType;BlockSize;ItemsPerThread;Threshold" PARENT_SCOPE) + set(list_across "\ +${TUNING_TYPES};64 128 256 512 1024;1 2 4 8 16;4 8 12 16" PARENT_SCOPE) + set(output_pattern_suffix "@InputType@_@BlockSize@_@ItemsPerThread@_@Threshold@" PARENT_SCOPE) elseif(file STREQUAL "benchmark_device_segmented_radix_sort_keys") set(list_across_names "\ -KeyType;LongBits;BlockSize;ItemsPerThread;WarpSmallLWS;WarpSmallIPT;WarpSmallBS;WarpPartition;WarpMediumLWS;WarpMediumIPT;WarpMediumBS" PARENT_SCOPE) +KeyType;RadixBits;BlockSize;ItemsPerThread;WarpSmallLWS;WarpSmallIPT;WarpSmallBS;WarpPartition;WarpMediumLWS;WarpMediumIPT;WarpMediumBS" PARENT_SCOPE) set(list_across "${TUNING_TYPES};8;256;4 8 16;8;4;256;64;16;8;256" PARENT_SCOPE) set(output_pattern_suffix "\ -@KeyType@_@LongBits@_@BlockSize@_@ItemsPerThread@_@WarpSmallLWS@_@WarpSmallIPT@_@WarpSmallBS@_@WarpPartition@_@WarpMediumLWS@_@WarpMediumIPT@_@WarpMediumBS@" PARENT_SCOPE) +@KeyType@_@RadixBits@_@BlockSize@_@ItemsPerThread@_@WarpSmallLWS@_@WarpSmallIPT@_@WarpSmallBS@_@WarpPartition@_@WarpMediumLWS@_@WarpMediumIPT@_@WarpMediumBS@" PARENT_SCOPE) elseif(file STREQUAL "benchmark_device_segmented_radix_sort_pairs") set(list_across_names "\ -KeyType;ValueType;LongBits;BlockSize;ItemsPerThread;WarpSmallLWS;WarpSmallIPT;WarpSmallBS;WarpPartition;WarpMediumLWS;WarpMediumIPT;WarpMediumBS" PARENT_SCOPE) +KeyType;ValueType;RadixBits;BlockSize;ItemsPerThread;WarpSmallLWS;WarpSmallIPT;WarpSmallBS;WarpPartition;WarpMediumLWS;WarpMediumIPT;WarpMediumBS" PARENT_SCOPE) set(list_across "${TUNING_TYPES};${LIMITED_TUNING_TYPES};8;256;4 8 16;8;4;256;64;16;8;256" PARENT_SCOPE) set(output_pattern_suffix "\ -@KeyType@_@ValueType@_@LongBits@_@BlockSize@_@ItemsPerThread@_@WarpSmallLWS@_@WarpSmallIPT@_@WarpSmallBS@_@WarpPartition@_@WarpMediumLWS@_@WarpMediumIPT@_@WarpMediumBS@" PARENT_SCOPE) +@KeyType@_@ValueType@_@RadixBits@_@BlockSize@_@ItemsPerThread@_@WarpSmallLWS@_@WarpSmallIPT@_@WarpSmallBS@_@WarpPartition@_@WarpMediumLWS@_@WarpMediumIPT@_@WarpMediumBS@" PARENT_SCOPE) elseif(file STREQUAL "benchmark_device_segmented_reduce") set(list_across_names "DataType;BlockSize;ItemsPerThread" PARENT_SCOPE) set(list_across "\ @@ -108,6 +113,12 @@ DataType;BlockSize;" PARENT_SCOPE) set(list_across "${TUNING_TYPES};64 128 256 512 1024" PARENT_SCOPE) set(output_pattern_suffix "\ @DataType@_@BlockSize@" PARENT_SCOPE) +elseif(file STREQUAL "benchmark_device_transform_pointer") + set(list_across_names "\ +DataType;BlockSize;LoadType" PARENT_SCOPE) + set(list_across "${TUNING_TYPES};64 128 256 512 1024;rocprim::load_default rocprim::load_nontemporal" PARENT_SCOPE) + set(output_pattern_suffix "\ +@DataType@_@BlockSize@_@LoadType@" PARENT_SCOPE) elseif(file STREQUAL "benchmark_device_partition") set(list_across_names "DataType;BlockSize" PARENT_SCOPE) set(list_across "${TUNING_TYPES};128 192 256 384 512" PARENT_SCOPE) diff --git a/benchmark/benchmark_block_adjacent_difference.cpp b/benchmark/benchmark_block_adjacent_difference.cpp index d9f7728b4..71af4c10f 100644 --- a/benchmark/benchmark_block_adjacent_difference.cpp +++ b/benchmark/benchmark_block_adjacent_difference.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -21,11 +21,8 @@ // SOFTWARE. #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" -// Google Benchmark -#include +#include "../common/utils_device_ptr.hpp" // HIP API #include @@ -45,10 +42,6 @@ #include #include -#ifndef DEFAULT_N -const size_t DEFAULT_BYTES = 1024 * 1024 * 128 * 4; -#endif - template -auto run_benchmark(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) +auto run_benchmark(benchmark_utils::state&& state) -> std::enable_if_t::value && !std::is_same::value> { + const auto& bytes = state.bytes; + const auto& seed = state.seed; + const auto& stream = state.stream; + // Calculate the number of elements N size_t N = bytes / sizeof(T); @@ -265,51 +259,20 @@ auto run_benchmark(benchmark::State& state, const std::vector input = get_random_data(size, random_range.first, random_range.second, seed.get_0()); - T* d_input; - T* d_output; - HIP_CHECK(hipMalloc(&d_input, input.size() * sizeof(input[0]))); - HIP_CHECK(hipMalloc(&d_output, input.size() * sizeof(T))); - HIP_CHECK( - hipMemcpy(d_input, input.data(), input.size() * sizeof(input[0]), hipMemcpyHostToDevice)); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - hipLaunchKernelGGL(HIP_KERNEL_NAME(kernel), - dim3(num_blocks), - dim3(BlockSize), - 0, - stream, - d_input, - d_output, - Trials); - HIP_CHECK(hipGetLastError()); - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } + common::device_ptr d_input(input); + common::device_ptr d_output(input.size()); - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * Trials * size * sizeof(T)); - state.SetItemsProcessed(state.iterations() * Trials * size); - - HIP_CHECK(hipFree(d_input)); - HIP_CHECK(hipFree(d_output)); + state.run( + [&] + { + kernel + <<>>(d_input.get(), + d_output.get(), + Trials); + HIP_CHECK(hipGetLastError()); + }); + + state.set_throughput(size * Trials, sizeof(T)); } template -auto run_benchmark(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) +auto run_benchmark(benchmark_utils::state&& state) -> std::enable_if_t::value || std::is_same::value> { + const auto& bytes = state.bytes; + const auto& seed = state.seed; + const auto& stream = state.stream; + // Calculate the number of elements N size_t N = bytes / sizeof(T); @@ -345,166 +309,73 @@ auto run_benchmark(benchmark::State& state, random_range_tile_sizes.second, seed.get_1()); - T* d_input; - unsigned int* d_tile_sizes; - T* d_output; - HIP_CHECK(hipMalloc(&d_input, input.size() * sizeof(input[0]))); - HIP_CHECK(hipMalloc(&d_tile_sizes, tile_sizes.size() * sizeof(tile_sizes[0]))); - HIP_CHECK(hipMalloc(&d_output, input.size() * sizeof(input[0]))); - HIP_CHECK( - hipMemcpy(d_input, input.data(), input.size() * sizeof(input[0]), hipMemcpyHostToDevice)); - HIP_CHECK(hipMemcpy(d_tile_sizes, - tile_sizes.data(), - tile_sizes.size() * sizeof(tile_sizes[0]), - hipMemcpyHostToDevice)); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - hipLaunchKernelGGL(HIP_KERNEL_NAME(kernel), - dim3(num_blocks), - dim3(BlockSize), - 0, - stream, - d_input, - d_tile_sizes, - d_output, - Trials); - HIP_CHECK(hipGetLastError()); - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * Trials * size * sizeof(T)); - state.SetItemsProcessed(state.iterations() * Trials * size); + common::device_ptr d_input(input); + common::device_ptr d_tile_sizes(tile_sizes); + common::device_ptr d_output(input.size()); - HIP_CHECK(hipFree(d_input)); - HIP_CHECK(hipFree(d_tile_sizes)); - HIP_CHECK(hipFree(d_output)); + state.run( + [&] + { + kernel + <<>>(d_input.get(), + d_tile_sizes.get(), + d_output.get(), + Trials); + HIP_CHECK(hipGetLastError()); + }); + + state.set_throughput(size * Trials, sizeof(T)); } #define CREATE_BENCHMARK(T, BS, IPT, WITH_TILE) \ - benchmark::RegisterBenchmark( \ + executor.queue_fn( \ bench_naming::format_name("{lvl:block,algo:adjacent_difference,subalgo:" + name \ + ",key_type:" #T ",cfg:{bs:" #BS ",ipt:" #IPT \ ",with_tile:" #WITH_TILE "}}") \ .c_str(), \ - run_benchmark, \ - bytes, \ - seed, \ - stream) + run_benchmark); -#define BENCHMARK_TYPE(type, block, with_tile) \ - CREATE_BENCHMARK(type, block, 1, with_tile), CREATE_BENCHMARK(type, block, 3, with_tile), \ - CREATE_BENCHMARK(type, block, 4, with_tile), CREATE_BENCHMARK(type, block, 8, with_tile), \ - CREATE_BENCHMARK(type, block, 16, with_tile), CREATE_BENCHMARK(type, block, 32, with_tile) +#define BENCHMARK_TYPE(type, block, with_tile) \ + CREATE_BENCHMARK(type, block, 1, with_tile) \ + CREATE_BENCHMARK(type, block, 3, with_tile) \ + CREATE_BENCHMARK(type, block, 4, with_tile) \ + CREATE_BENCHMARK(type, block, 8, with_tile) \ + CREATE_BENCHMARK(type, block, 16, with_tile) \ + CREATE_BENCHMARK(type, block, 32, with_tile) template -void add_benchmarks(const std::string& name, - std::vector& benchmarks, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) +void add_benchmarks(const std::string& name, benchmark_utils::executor& executor) { - std::vector bs - = {BENCHMARK_TYPE(int, 256, false), - BENCHMARK_TYPE(float, 256, false), - BENCHMARK_TYPE(int8_t, 256, false), - BENCHMARK_TYPE(rocprim::half, 256, false), - BENCHMARK_TYPE(long long, 256, false), - BENCHMARK_TYPE(double, 256, false), - BENCHMARK_TYPE(rocprim::int128_t, 256, false), - BENCHMARK_TYPE(rocprim::uint128_t, 256, false)}; + BENCHMARK_TYPE(int, 256, false) + BENCHMARK_TYPE(float, 256, false) + BENCHMARK_TYPE(int8_t, 256, false) + BENCHMARK_TYPE(rocprim::half, 256, false) + BENCHMARK_TYPE(long long, 256, false) + BENCHMARK_TYPE(double, 256, false) + BENCHMARK_TYPE(rocprim::int128_t, 256, false) + BENCHMARK_TYPE(rocprim::uint128_t, 256, false) if(!std::is_same::value) { - bs.insert(bs.end(), - {BENCHMARK_TYPE(int, 256, true), - BENCHMARK_TYPE(float, 256, true), - BENCHMARK_TYPE(int8_t, 256, true), - BENCHMARK_TYPE(rocprim::half, 256, true), - BENCHMARK_TYPE(long long, 256, true), - BENCHMARK_TYPE(double, 256, true), - BENCHMARK_TYPE(rocprim::int128_t, 256, true), - BENCHMARK_TYPE(rocprim::uint128_t, 256, true)}); + BENCHMARK_TYPE(int, 256, true) + BENCHMARK_TYPE(float, 256, true) + BENCHMARK_TYPE(int8_t, 256, true) + BENCHMARK_TYPE(rocprim::half, 256, true) + BENCHMARK_TYPE(long long, 256, true) + BENCHMARK_TYPE(double, 256, true) + BENCHMARK_TYPE(rocprim::int128_t, 256, true) + BENCHMARK_TYPE(rocprim::uint128_t, 256, true) } - - benchmarks.insert(benchmarks.end(), bs.begin(), bs.end()); } int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); - - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks; - add_benchmarks("subtract_left", benchmarks, bytes, seed, stream); - add_benchmarks("subtract_right", benchmarks, bytes, seed, stream); - add_benchmarks("subtract_left_partial", benchmarks, bytes, seed, stream); - add_benchmarks("subtract_right_partial", - benchmarks, - bytes, - seed, - stream); - - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } + benchmark_utils::executor executor(argc, argv, 512 * benchmark_utils::MiB, 1, 0); - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } + add_benchmarks("subtract_left", executor); + add_benchmarks("subtract_right", executor); + add_benchmarks("subtract_left_partial", executor); + add_benchmarks("subtract_right_partial", executor); - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_block_discontinuity.cpp b/benchmark/benchmark_block_discontinuity.cpp index 010def843..62fd01a17 100644 --- a/benchmark/benchmark_block_discontinuity.cpp +++ b/benchmark/benchmark_block_discontinuity.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,12 +20,7 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -// CmdParser #include "benchmark_utils.hpp" -#include "cmdparser.hpp" - -// Google Benchmark -#include // HIP API #include @@ -44,10 +39,6 @@ #include #include -#ifndef DEFAULT_N -const size_t DEFAULT_BYTES = 1024 * 1024 * 128 * 4; -#endif - template -void run_benchmark(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) +void run_benchmark(benchmark_utils::state&& state) { + const auto& bytes = state.bytes; + const auto& seed = state.seed; + const auto& stream = state.stream; + // Calculate the number of elements N size_t N = bytes / sizeof(T); @@ -220,139 +212,60 @@ void run_benchmark(benchmark::State& state, HIP_CHECK(hipMemcpy(d_input, input.data(), size * sizeof(T), hipMemcpyHostToDevice)); HIP_CHECK(hipDeviceSynchronize()); - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - hipLaunchKernelGGL( - HIP_KERNEL_NAME(kernel), - dim3(size / items_per_block), - dim3(BlockSize), - 0, - stream, - d_input, - d_output); - HIP_CHECK(hipGetLastError()); - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); + state.run( + [&] + { + kernel + <<>>(d_input, d_output); + HIP_CHECK(hipGetLastError()); + }); - state.SetBytesProcessed(state.iterations() * Trials * size * sizeof(T)); - state.SetItemsProcessed(state.iterations() * Trials * size); + state.set_throughput(size * Trials, sizeof(T)); HIP_CHECK(hipFree(d_input)); HIP_CHECK(hipFree(d_output)); } -#define CREATE_BENCHMARK(T, BS, IPT, WITH_TILE) \ - benchmark::RegisterBenchmark( \ - bench_naming::format_name("{lvl:block,algo:discontinuity,subalgo:" + name \ - + ",key_type:" #T ",cfg:{bs:" #BS ",ipt:" #IPT \ - ",with_tile:" #WITH_TILE "}}") \ - .c_str(), \ - run_benchmark, \ - bytes, \ - seed, \ - stream) - -#define BENCHMARK_TYPE(type, block, bool) \ - CREATE_BENCHMARK(type, block, 1, bool), CREATE_BENCHMARK(type, block, 2, bool), \ - CREATE_BENCHMARK(type, block, 3, bool), CREATE_BENCHMARK(type, block, 4, bool), \ - CREATE_BENCHMARK(type, block, 8, bool) +#define CREATE_BENCHMARK(T, BS, IPT, WITH_TILE) \ + executor.queue_fn(bench_naming::format_name("{lvl:block,algo:discontinuity,subalgo:" + name \ + + ",key_type:" #T ",cfg:{bs:" #BS ",ipt:" #IPT \ + ",with_tile:" #WITH_TILE "}}") \ + .c_str(), \ + run_benchmark); + +#define BENCHMARK_TYPE(type, block, bool) \ + CREATE_BENCHMARK(type, block, 1, bool) \ + CREATE_BENCHMARK(type, block, 2, bool) \ + CREATE_BENCHMARK(type, block, 3, bool) \ + CREATE_BENCHMARK(type, block, 4, bool) \ + CREATE_BENCHMARK(type, block, 8, bool) template -void add_benchmarks(const std::string& name, - std::vector& benchmarks, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) +void add_benchmarks(const std::string& name, benchmark_utils::executor& executor) { - std::vector bs - = {BENCHMARK_TYPE(int, 256, false), - BENCHMARK_TYPE(int, 256, true), - BENCHMARK_TYPE(int8_t, 256, false), - BENCHMARK_TYPE(int8_t, 256, true), - BENCHMARK_TYPE(uint8_t, 256, false), - BENCHMARK_TYPE(uint8_t, 256, true), - BENCHMARK_TYPE(rocprim::half, 256, false), - BENCHMARK_TYPE(rocprim::half, 256, true), - BENCHMARK_TYPE(long long, 256, false), - BENCHMARK_TYPE(long long, 256, true), - BENCHMARK_TYPE(rocprim::int128_t, 256, false), - BENCHMARK_TYPE(rocprim::int128_t, 256, true), - BENCHMARK_TYPE(rocprim::uint128_t, 256, false), - BENCHMARK_TYPE(rocprim::uint128_t, 256, true)}; - - benchmarks.insert(benchmarks.end(), bs.begin(), bs.end()); + BENCHMARK_TYPE(int, 256, false) + BENCHMARK_TYPE(int, 256, true) + BENCHMARK_TYPE(int8_t, 256, false) + BENCHMARK_TYPE(int8_t, 256, true) + BENCHMARK_TYPE(uint8_t, 256, false) + BENCHMARK_TYPE(uint8_t, 256, true) + BENCHMARK_TYPE(rocprim::half, 256, false) + BENCHMARK_TYPE(rocprim::half, 256, true) + BENCHMARK_TYPE(long long, 256, false) + BENCHMARK_TYPE(long long, 256, true) + BENCHMARK_TYPE(rocprim::int128_t, 256, false) + BENCHMARK_TYPE(rocprim::int128_t, 256, true) + BENCHMARK_TYPE(rocprim::uint128_t, 256, false) + BENCHMARK_TYPE(rocprim::uint128_t, 256, true) } int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); - - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks; - add_benchmarks("flag_heads", benchmarks, bytes, seed, stream); - add_benchmarks("flag_tails", benchmarks, bytes, seed, stream); - add_benchmarks("flag_heads_and_tails", benchmarks, bytes, seed, stream); - - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } + benchmark_utils::executor executor(argc, argv, 512 * benchmark_utils::MiB, 1, 0); - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } + add_benchmarks("flag_heads", executor); + add_benchmarks("flag_tails", executor); + add_benchmarks("flag_heads_and_tails", executor); - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_block_exchange.cpp b/benchmark/benchmark_block_exchange.cpp index 6a5e211a7..889fc7826 100644 --- a/benchmark/benchmark_block_exchange.cpp +++ b/benchmark/benchmark_block_exchange.cpp @@ -20,14 +20,10 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -// CmdParser #include "benchmark_utils.hpp" -#include "cmdparser.hpp" #include "../common/utils_custom_type.hpp" - -// Google Benchmark -#include +#include "../common/utils_device_ptr.hpp" // HIP API #include @@ -47,10 +43,6 @@ #include #include -#ifndef DEFAULT_N -const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; -#endif - template -void run_benchmark(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) +void run_benchmark(benchmark_utils::state&& state) { + const auto& bytes = state.bytes; + const auto& seed = state.seed; + const auto& stream = state.stream; + // Calculate the number of elements N size_t N = bytes / sizeof(T); @@ -241,159 +234,68 @@ void run_benchmark(benchmark::State& state, std::iota(block_ranks, block_ranks + items_per_block, 0); std::shuffle(block_ranks, block_ranks + items_per_block, gen); } - T* d_input; - unsigned int* d_ranks; - T* d_output; - HIP_CHECK(hipMalloc(reinterpret_cast(&d_input), size * sizeof(T))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_ranks), size * sizeof(unsigned int))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_output), size * sizeof(T))); - HIP_CHECK(hipMemcpy(d_input, input.data(), size * sizeof(T), hipMemcpyHostToDevice)); - HIP_CHECK(hipMemcpy(d_ranks, ranks.data(), size * sizeof(unsigned int), hipMemcpyHostToDevice)); + common::device_ptr d_input(input); + common::device_ptr d_ranks(ranks); + common::device_ptr d_output(size); HIP_CHECK(hipDeviceSynchronize()); - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - hipLaunchKernelGGL(HIP_KERNEL_NAME(kernel), - dim3(size / items_per_block), - dim3(BlockSize), - 0, - stream, - d_input, - d_ranks, - d_output); - HIP_CHECK(hipGetLastError()); - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * Trials * size * sizeof(T)); - state.SetItemsProcessed(state.iterations() * Trials * size); - - HIP_CHECK(hipFree(d_input)); - HIP_CHECK(hipFree(d_ranks)); - HIP_CHECK(hipFree(d_output)); + state.run( + [&] + { + kernel + <<>>(d_input.get(), + d_ranks.get(), + d_output.get()); + HIP_CHECK(hipGetLastError()); + }); + + state.set_throughput(size * Trials, sizeof(T)); } -#define CREATE_BENCHMARK(T, BS, IPT) \ - benchmark::RegisterBenchmark( \ - bench_naming::format_name("{lvl:block,algo:exchange,subalgo:" + name \ - + ",key_type:" #T ",cfg:{bs:" #BS ",ipt:" #IPT "}}") \ - .c_str(), \ - run_benchmark, \ - bytes, \ - seed, \ - stream) - -#define BENCHMARK_TYPE(type, block) \ - CREATE_BENCHMARK(type, block, 1), CREATE_BENCHMARK(type, block, 2), \ - CREATE_BENCHMARK(type, block, 3), CREATE_BENCHMARK(type, block, 4), \ - CREATE_BENCHMARK(type, block, 7), CREATE_BENCHMARK(type, block, 8) +#define CREATE_BENCHMARK(T, BS, IPT) \ + executor.queue_fn(bench_naming::format_name("{lvl:block,algo:exchange,subalgo:" + name \ + + ",key_type:" #T ",cfg:{bs:" #BS ",ipt:" #IPT \ + "}}") \ + .c_str(), \ + run_benchmark); + +#define BENCHMARK_TYPE(type, block) \ + CREATE_BENCHMARK(type, block, 1) \ + CREATE_BENCHMARK(type, block, 2) \ + CREATE_BENCHMARK(type, block, 3) \ + CREATE_BENCHMARK(type, block, 4) \ + CREATE_BENCHMARK(type, block, 7) \ + CREATE_BENCHMARK(type, block, 8) template -void add_benchmarks(const std::string& name, - std::vector& benchmarks, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) +void add_benchmarks(const std::string& name, benchmark_utils::executor& executor) { using custom_float2 = common::custom_type; using custom_double2 = common::custom_type; - std::vector bs = {BENCHMARK_TYPE(int, 256), - BENCHMARK_TYPE(int8_t, 256), - BENCHMARK_TYPE(rocprim::half, 256), - BENCHMARK_TYPE(long long, 256), - BENCHMARK_TYPE(custom_float2, 256), - BENCHMARK_TYPE(float2, 256), - BENCHMARK_TYPE(custom_double2, 256), - BENCHMARK_TYPE(double2, 256), - BENCHMARK_TYPE(float4, 256), - BENCHMARK_TYPE(rocprim::int128_t, 256), - BENCHMARK_TYPE(rocprim::uint128_t, 256)}; - - benchmarks.insert(benchmarks.end(), bs.begin(), bs.end()); + BENCHMARK_TYPE(int, 256) + BENCHMARK_TYPE(int8_t, 256) + BENCHMARK_TYPE(rocprim::half, 256) + BENCHMARK_TYPE(long long, 256) + BENCHMARK_TYPE(custom_float2, 256) + BENCHMARK_TYPE(float2, 256) + BENCHMARK_TYPE(custom_double2, 256) + BENCHMARK_TYPE(double2, 256) + BENCHMARK_TYPE(float4, 256) + BENCHMARK_TYPE(rocprim::int128_t, 256) + BENCHMARK_TYPE(rocprim::uint128_t, 256) } int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("bytes", "bytes", DEFAULT_BYTES, "number of values"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("bytes"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); - - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks; - add_benchmarks("blocked_to_striped", benchmarks, bytes, seed, stream); - add_benchmarks("striped_to_blocked", benchmarks, bytes, seed, stream); - add_benchmarks("blocked_to_warp_striped", - benchmarks, - bytes, - seed, - stream); - add_benchmarks("warp_striped_to_blocked", - benchmarks, - bytes, - seed, - stream); - add_benchmarks("scatter_to_blocked", benchmarks, bytes, seed, stream); - add_benchmarks("scatter_to_striped", benchmarks, bytes, seed, stream); - - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } + benchmark_utils::executor executor(argc, argv, 128 * benchmark_utils::MiB, 1, 0); - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } + add_benchmarks("blocked_to_striped", executor); + add_benchmarks("striped_to_blocked", executor); + add_benchmarks("blocked_to_warp_striped", executor); + add_benchmarks("warp_striped_to_blocked", executor); + add_benchmarks("scatter_to_blocked", executor); + add_benchmarks("scatter_to_striped", executor); - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_block_histogram.cpp b/benchmark/benchmark_block_histogram.cpp index cf454b60c..b1083145a 100644 --- a/benchmark/benchmark_block_histogram.cpp +++ b/benchmark/benchmark_block_histogram.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -21,11 +21,8 @@ // SOFTWARE. #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" -// Google Benchmark -#include +#include "../common/utils_device_ptr.hpp" // HIP API #include @@ -35,14 +32,6 @@ #include #include -#include -#include -#include - -#ifndef DEFAULT_N -const size_t DEFAULT_BYTES = 1024 * 1024 * 128 * 4; -#endif - template -void run_benchmark(benchmark::State& state, hipStream_t stream, size_t bytes) +void run_benchmark(benchmark_utils::state&& state) { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + // Calculate the number of elements N size_t N = bytes / sizeof(T); // Make sure size is a multiple of BlockSize @@ -120,169 +112,75 @@ void run_benchmark(benchmark::State& state, hipStream_t stream, size_t bytes) const auto bin_size = BinSize * ((N + items_per_block - 1) / items_per_block); // Allocate and fill memory std::vector input(size, 0.0f); - T* d_input; - T* d_output; - HIP_CHECK(hipMalloc(reinterpret_cast(&d_input), size * sizeof(T))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_output), bin_size * sizeof(T))); - HIP_CHECK(hipMemcpy(d_input, input.data(), size * sizeof(T), hipMemcpyHostToDevice)); + common::device_ptr d_input(input); + common::device_ptr d_output(bin_size); HIP_CHECK(hipDeviceSynchronize()); - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - hipLaunchKernelGGL( - HIP_KERNEL_NAME(kernel), - dim3(size / items_per_block), - dim3(BlockSize), - 0, - stream, - d_input, - d_output); - HIP_CHECK(hipGetLastError()); - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * size * sizeof(T) * Trials); - state.SetItemsProcessed(state.iterations() * size * Trials); - - HIP_CHECK(hipFree(d_input)); - HIP_CHECK(hipFree(d_output)); -} - -// IPT - items per thread -#define CREATE_BENCHMARK(T, BS, IPT) \ - benchmark::RegisterBenchmark( \ - bench_naming::format_name("{lvl:block,algo:histogram,key_type:" #T ",cfg:{bs:" #BS \ - ",ipt:" #IPT ",method:" \ - + method_name + "}}") \ - .c_str(), \ - run_benchmark, \ - stream, \ - bytes) - -#define BENCHMARK_TYPE(type, block) \ - CREATE_BENCHMARK(type, block, 1), CREATE_BENCHMARK(type, block, 2), \ - CREATE_BENCHMARK(type, block, 3), CREATE_BENCHMARK(type, block, 4), \ - CREATE_BENCHMARK(type, block, 8), CREATE_BENCHMARK(type, block, 16) - -#define BENCHMARK_TYPE_128(type, block) \ - CREATE_BENCHMARK(type, block, 1), CREATE_BENCHMARK(type, block, 2), \ - CREATE_BENCHMARK(type, block, 3), CREATE_BENCHMARK(type, block, 4), \ - CREATE_BENCHMARK(type, block, 8), CREATE_BENCHMARK(type, block, 12) - -template< - typename Benchmark, - std::enable_if_t< - std::is_same>::value, - bool> - = true> -void add_benchmarks(std::vector& benchmarks, - const std::string& method_name, - hipStream_t stream, - size_t bytes) -{ - std::vector new_benchmarks - = {BENCHMARK_TYPE(int, 256), - BENCHMARK_TYPE(int, 320), - BENCHMARK_TYPE(int, 512), + state.run( + [&] + { + kernel + <<>>(d_input.get(), + d_output.get()); + HIP_CHECK(hipGetLastError()); + }); - BENCHMARK_TYPE(unsigned long long, 256), - BENCHMARK_TYPE(unsigned long long, 320)}; - benchmarks.insert(benchmarks.end(), new_benchmarks.begin(), new_benchmarks.end()); + state.set_throughput(size * Trials, sizeof(T)); } -template< - typename Benchmark, - std::enable_if_t< - std::is_same>::value, - bool> - = true> -void add_benchmarks(std::vector& benchmarks, - const std::string& method_name, - hipStream_t stream, - size_t bytes) -{ - std::vector new_benchmarks - = {BENCHMARK_TYPE(int, 256), - BENCHMARK_TYPE(int, 320), - BENCHMARK_TYPE(int, 512), - - BENCHMARK_TYPE(unsigned long long, 256), - BENCHMARK_TYPE(unsigned long long, 320), - - BENCHMARK_TYPE_128(rocprim::int128_t, 256), - BENCHMARK_TYPE_128(rocprim::uint128_t, 256)}; - benchmarks.insert(benchmarks.end(), new_benchmarks.begin(), new_benchmarks.end()); -} +#define CREATE_BENCHMARK(Benchmark, method, T, BS, IPT) \ + executor.queue_fn(bench_naming::format_name("{lvl:block,algo:histogram,key_type:" #T \ + ",cfg:{bs:" #BS ",ipt:" #IPT ",method:" \ + + std::string(method) + "}}") \ + .c_str(), \ + run_benchmark); + +#define BENCHMARK_TYPE(Benchmark, method, T, BS) \ + CREATE_BENCHMARK(Benchmark, method, T, BS, 1) \ + CREATE_BENCHMARK(Benchmark, method, T, BS, 2) \ + CREATE_BENCHMARK(Benchmark, method, T, BS, 3) \ + CREATE_BENCHMARK(Benchmark, method, T, BS, 4) \ + CREATE_BENCHMARK(Benchmark, method, T, BS, 8) \ + CREATE_BENCHMARK(Benchmark, method, T, BS, 16) + +#define BENCHMARK_TYPE_128(Benchmark, method, T, BS) \ + CREATE_BENCHMARK(Benchmark, method, T, BS, 1) \ + CREATE_BENCHMARK(Benchmark, method, T, BS, 2) \ + CREATE_BENCHMARK(Benchmark, method, T, BS, 3) \ + CREATE_BENCHMARK(Benchmark, method, T, BS, 4) \ + CREATE_BENCHMARK(Benchmark, method, T, BS, 8) \ + CREATE_BENCHMARK(Benchmark, method, T, BS, 12) + +#define BENCHMARK_ATOMIC() \ + BENCHMARK_TYPE(histogram_atomic_t, "using_atomic", int, 256) \ + BENCHMARK_TYPE(histogram_atomic_t, "using_atomic", int, 320) \ + BENCHMARK_TYPE(histogram_atomic_t, "using_atomic", int, 512) \ + \ + BENCHMARK_TYPE(histogram_atomic_t, "using_atomic", unsigned long long, 256) \ + BENCHMARK_TYPE(histogram_atomic_t, "using_atomic", unsigned long long, 320) + +#define BENCHMARK_SORT() \ + BENCHMARK_TYPE(histogram_sort_t, "using_sort", int, 256) \ + BENCHMARK_TYPE(histogram_sort_t, "using_sort", int, 320) \ + BENCHMARK_TYPE(histogram_sort_t, "using_sort", int, 512) \ + \ + BENCHMARK_TYPE(histogram_sort_t, "using_sort", unsigned long long, 256) \ + BENCHMARK_TYPE(histogram_sort_t, "using_sort", unsigned long long, 320) \ + \ + BENCHMARK_TYPE_128(histogram_sort_t, "using_sort", rocprim::int128_t, 256) \ + BENCHMARK_TYPE_128(histogram_sort_t, "using_sort", rocprim::uint128_t, 256) int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); + benchmark_utils::executor executor(argc, argv, 512 * benchmark_utils::MiB, 1, 0); - // Add benchmarks - std::vector benchmarks; - // using_atomic - using histogram_a_t = histogram; - add_benchmarks(benchmarks, "using_atomic", stream, bytes); - // using_sort - using histogram_s_t = histogram; - add_benchmarks(benchmarks, "using_sort", stream, bytes); +#ifndef BENCHMARK_CONFIG_TUNING + using histogram_atomic_t = histogram; + using histogram_sort_t = histogram; - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } + BENCHMARK_ATOMIC() + BENCHMARK_SORT() +#endif - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_block_radix_rank.cpp b/benchmark/benchmark_block_radix_rank.cpp index 61faa3266..a70803779 100644 --- a/benchmark/benchmark_block_radix_rank.cpp +++ b/benchmark/benchmark_block_radix_rank.cpp @@ -21,13 +21,9 @@ // SOFTWARE. #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" #include "../common/utils_data_generation.hpp" - -// Google Benchmark -#include +#include "../common/utils_device_ptr.hpp" // HIP API #include @@ -44,10 +40,6 @@ #include #include -#ifndef DEFAULT_N -const size_t DEFAULT_BYTES = 1024 * 1024 * 128 * 4; -#endif - template -void run_benchmark(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) + size_t Trials = 10> +void run_benchmark(benchmark_utils::state&& state) { + const auto& bytes = state.bytes; + const auto& seed = state.seed; + const auto& stream = state.stream; + // Calculate the number of elements N - size_t N = bytes / sizeof(T); - constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; - const unsigned int grid_size = ((N + items_per_block - 1) / items_per_block); - const unsigned int size = items_per_block * grid_size; + size_t N = bytes / sizeof(T); + constexpr size_t items_per_block = BlockSize * ItemsPerThread; + const size_t grid_size = ((N + items_per_block - 1) / items_per_block); + const size_t size = items_per_block * grid_size; std::vector input = get_random_data(size, common::generate_limits::min(), common::generate_limits::max(), seed.get_0()); - T* d_input; - unsigned int* d_output; - HIP_CHECK(hipMalloc(&d_input, size * sizeof(T))); - HIP_CHECK(hipMalloc(&d_output, size * sizeof(unsigned int))); - HIP_CHECK(hipMemcpy(d_input, input.data(), size * sizeof(T), hipMemcpyHostToDevice)); + common::device_ptr d_input(input); + common::device_ptr d_output(size); HIP_CHECK(hipDeviceSynchronize()); - for(auto _ : state) - { - auto start = std::chrono::steady_clock::now(); - - hipLaunchKernelGGL(HIP_KERNEL_NAME(rank_kernel), - dim3(grid_size), - dim3(BlockSize), - 0, - stream, - d_input, - d_output); - HIP_CHECK(hipPeekAtLastError()); - HIP_CHECK(hipDeviceSynchronize()); - - auto end = std::chrono::steady_clock::now(); - auto elapsed_seconds - = std::chrono::duration_cast>(end - start); - state.SetIterationTime(elapsed_seconds.count()); - } - state.SetBytesProcessed(state.iterations() * Trials * size * sizeof(T)); - state.SetItemsProcessed(state.iterations() * Trials * size); + state.run( + [&] + { + rank_kernel + <<>>(d_input.get(), d_output.get()); + HIP_CHECK(hipGetLastError()); + }); - HIP_CHECK(hipFree(d_input)); - HIP_CHECK(hipFree(d_output)); + state.set_throughput(size * Trials, sizeof(T)); } -#define CREATE_BENCHMARK(T, BS, IPT, KIND) \ - benchmark::RegisterBenchmark( \ - bench_naming::format_name("{lvl:block,algo:radix_rank,key_type:" #T ",cfg:{bs:" #BS \ - ",ipt:" #IPT ",method:" #KIND "}}") \ - .c_str(), \ - run_benchmark, \ - bytes, \ - seed, \ - stream) +#define CREATE_BENCHMARK(T, BS, IPT, KIND) \ + executor.queue_fn(bench_naming::format_name("{lvl:block,algo:radix_rank,key_type:" #T \ + ",cfg:{bs:" #BS ",ipt:" #IPT ",method:" #KIND \ + "}}") \ + .c_str(), \ + run_benchmark); // clang-format off -#define CREATE_BENCHMARK_KINDS(type, block, ipt) \ - CREATE_BENCHMARK(type, block, ipt, rocprim::block_radix_rank_algorithm::basic), \ - CREATE_BENCHMARK(type, block, ipt, rocprim::block_radix_rank_algorithm::basic_memoize), \ +#define CREATE_BENCHMARK_KINDS(type, block, ipt) \ + CREATE_BENCHMARK(type, block, ipt, rocprim::block_radix_rank_algorithm::basic) \ + CREATE_BENCHMARK(type, block, ipt, rocprim::block_radix_rank_algorithm::basic_memoize) \ CREATE_BENCHMARK(type, block, ipt, rocprim::block_radix_rank_algorithm::match) -#define BENCHMARK_TYPE(type, block) \ - CREATE_BENCHMARK_KINDS(type, block, 1), \ - CREATE_BENCHMARK_KINDS(type, block, 4), \ - CREATE_BENCHMARK_KINDS(type, block, 8), \ - CREATE_BENCHMARK_KINDS(type, block, 12), \ - CREATE_BENCHMARK_KINDS(type, block, 16), \ +#define BENCHMARK_TYPE(type, block) \ + CREATE_BENCHMARK_KINDS(type, block, 1) \ + CREATE_BENCHMARK_KINDS(type, block, 4) \ + CREATE_BENCHMARK_KINDS(type, block, 8) \ + CREATE_BENCHMARK_KINDS(type, block, 12) \ + CREATE_BENCHMARK_KINDS(type, block, 16) \ CREATE_BENCHMARK_KINDS(type, block, 20) // clang-format on -void add_benchmarks(std::vector& benchmarks, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) -{ - std::vector bs = {BENCHMARK_TYPE(int, 128), - BENCHMARK_TYPE(int, 256), - BENCHMARK_TYPE(int, 512), - - BENCHMARK_TYPE(uint8_t, 128), - BENCHMARK_TYPE(uint8_t, 256), - BENCHMARK_TYPE(uint8_t, 512), - - BENCHMARK_TYPE(long long, 128), - BENCHMARK_TYPE(long long, 256), - BENCHMARK_TYPE(long long, 512), - - BENCHMARK_TYPE(rocprim::int128_t, 128), - BENCHMARK_TYPE(rocprim::int128_t, 256), - BENCHMARK_TYPE(rocprim::int128_t, 512), - - BENCHMARK_TYPE(rocprim::uint128_t, 128), - BENCHMARK_TYPE(rocprim::uint128_t, 256), - BENCHMARK_TYPE(rocprim::uint128_t, 512)}; - - benchmarks.insert(benchmarks.end(), bs.begin(), bs.end()); -} - int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); + benchmark_utils::executor executor(argc, argv, 512 * benchmark_utils::MiB, 1, 0); - // HIP - hipStream_t stream = 0; // default + BENCHMARK_TYPE(int, 128) + BENCHMARK_TYPE(int, 256) + BENCHMARK_TYPE(int, 512) - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); + BENCHMARK_TYPE(uint8_t, 128) + BENCHMARK_TYPE(uint8_t, 256) + BENCHMARK_TYPE(uint8_t, 512) - // Add benchmarks - std::vector benchmarks; - add_benchmarks(benchmarks, bytes, seed, stream); + BENCHMARK_TYPE(long long, 128) + BENCHMARK_TYPE(long long, 256) + BENCHMARK_TYPE(long long, 512) - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } + BENCHMARK_TYPE(rocprim::int128_t, 128) + BENCHMARK_TYPE(rocprim::int128_t, 256) + BENCHMARK_TYPE(rocprim::int128_t, 512) - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } + BENCHMARK_TYPE(rocprim::uint128_t, 128) + BENCHMARK_TYPE(rocprim::uint128_t, 256) + BENCHMARK_TYPE(rocprim::uint128_t, 512) - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_block_radix_sort.cpp b/benchmark/benchmark_block_radix_sort.cpp index a8f1122d2..331f80944 100644 --- a/benchmark/benchmark_block_radix_sort.cpp +++ b/benchmark/benchmark_block_radix_sort.cpp @@ -21,14 +21,10 @@ // SOFTWARE. #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" #include "../common/utils_custom_type.hpp" #include "../common/utils_data_generation.hpp" - -// Google Benchmark -#include +#include "../common/utils_device_ptr.hpp" // HIP API #include @@ -46,10 +42,6 @@ #include #include -#ifndef DEFAULT_N -const size_t DEFAULT_BYTES = 1024 * 1024 * 128 * 4; -#endif - enum class benchmark_kinds { sort_keys, @@ -126,16 +118,17 @@ void sort_pairs_kernel(const T* input, T* output) } template -void run_benchmark(benchmark::State& state, - benchmark_kinds benchmark_kind, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) + benchmark_kinds BenchmarkKind, + unsigned int BlockSize, + unsigned int RadixBitsPerPass, + unsigned int ItemsPerThread, + unsigned int Trials = 10> +void run_benchmark(benchmark_utils::state&& state) { + const auto& bytes = state.bytes; + const auto& seed = state.seed; + const auto& stream = state.stream; + // Calculate the number of elements N size_t N = bytes / sizeof(T); @@ -147,227 +140,137 @@ void run_benchmark(benchmark::State& state, common::generate_limits::max(), seed.get_0()); - T* d_input; - T* d_output; - HIP_CHECK(hipMalloc(reinterpret_cast(&d_input), size * sizeof(T))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_output), size * sizeof(T))); - HIP_CHECK(hipMemcpy(d_input, input.data(), size * sizeof(T), hipMemcpyHostToDevice)); + common::device_ptr d_input(input); + common::device_ptr d_output(size); HIP_CHECK(hipDeviceSynchronize()); - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - if(benchmark_kind == benchmark_kinds::sort_keys) + state.run( + [&] { - hipLaunchKernelGGL( - HIP_KERNEL_NAME( - sort_keys_kernel), - dim3(size / items_per_block), - dim3(BlockSize), - 0, - stream, - d_input, - d_output); - } - else if(benchmark_kind == benchmark_kinds::sort_pairs) - { - hipLaunchKernelGGL( - HIP_KERNEL_NAME( - sort_pairs_kernel), - dim3(size / items_per_block), - dim3(BlockSize), - 0, - stream, - d_input, - d_output); - } - HIP_CHECK(hipGetLastError()); - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * Trials * size * sizeof(T)); - state.SetItemsProcessed(state.iterations() * Trials * size); - - HIP_CHECK(hipFree(d_input)); - HIP_CHECK(hipFree(d_output)); + if constexpr(BenchmarkKind == benchmark_kinds::sort_keys) + { + sort_keys_kernel + <<>>(d_input.get(), + d_output.get()); + } + else if constexpr(BenchmarkKind == benchmark_kinds::sort_pairs) + { + sort_pairs_kernel + <<>>(d_input.get(), + d_output.get()); + } + HIP_CHECK(hipGetLastError()); + }); + + state.set_throughput(size * Trials, sizeof(T)); } #define CREATE_BENCHMARK(T, BS, RB, IPT) \ - benchmark::RegisterBenchmark( \ + executor.queue_fn( \ bench_naming::format_name("{lvl:block,algo:radix_sort,key_type:" #T ",subalgo:" + name \ + ",cfg:{bs:" #BS ",rb:" #RB ",ipt:" #IPT "}}") \ .c_str(), \ - run_benchmark, \ - benchmark_kind, \ - bytes, \ - seed, \ - stream) - -#define BENCHMARK_TYPE(type, block, radix_bits) \ - CREATE_BENCHMARK(type, block, radix_bits, 1), CREATE_BENCHMARK(type, block, radix_bits, 2), \ - CREATE_BENCHMARK(type, block, radix_bits, 3), \ - CREATE_BENCHMARK(type, block, radix_bits, 4), CREATE_BENCHMARK(type, block, radix_bits, 8) - -void add_benchmarks(benchmark_kinds benchmark_kind, - const std::string& name, - std::vector& benchmarks, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) + run_benchmark); + +#define BENCHMARK_TYPE(type, block, radix_bits) \ + CREATE_BENCHMARK(type, block, radix_bits, 1) \ + CREATE_BENCHMARK(type, block, radix_bits, 2) \ + CREATE_BENCHMARK(type, block, radix_bits, 3) \ + CREATE_BENCHMARK(type, block, radix_bits, 4) \ + CREATE_BENCHMARK(type, block, radix_bits, 8) + +template +void add_benchmarks(const std::string& name, benchmark_utils::executor& executor) { using custom_int_type = common::custom_type; - std::vector bs = {BENCHMARK_TYPE(int, 64, 3), - BENCHMARK_TYPE(int, 512, 3), - - BENCHMARK_TYPE(int, 64, 4), - BENCHMARK_TYPE(int, 128, 4), - BENCHMARK_TYPE(int, 192, 4), - BENCHMARK_TYPE(int, 256, 4), - BENCHMARK_TYPE(int, 320, 4), - BENCHMARK_TYPE(int, 512, 4), - - BENCHMARK_TYPE(int8_t, 64, 3), - BENCHMARK_TYPE(int8_t, 512, 3), - - BENCHMARK_TYPE(int8_t, 64, 4), - BENCHMARK_TYPE(int8_t, 128, 4), - BENCHMARK_TYPE(int8_t, 192, 4), - BENCHMARK_TYPE(int8_t, 256, 4), - BENCHMARK_TYPE(int8_t, 320, 4), - BENCHMARK_TYPE(int8_t, 512, 4), - - BENCHMARK_TYPE(uint8_t, 64, 3), - BENCHMARK_TYPE(uint8_t, 512, 3), - - BENCHMARK_TYPE(uint8_t, 64, 4), - BENCHMARK_TYPE(uint8_t, 128, 4), - BENCHMARK_TYPE(uint8_t, 192, 4), - BENCHMARK_TYPE(uint8_t, 256, 4), - BENCHMARK_TYPE(uint8_t, 320, 4), - BENCHMARK_TYPE(uint8_t, 512, 4), - - BENCHMARK_TYPE(rocprim::half, 64, 3), - BENCHMARK_TYPE(rocprim::half, 512, 3), - - BENCHMARK_TYPE(rocprim::half, 64, 4), - BENCHMARK_TYPE(rocprim::half, 128, 4), - BENCHMARK_TYPE(rocprim::half, 192, 4), - BENCHMARK_TYPE(rocprim::half, 256, 4), - BENCHMARK_TYPE(rocprim::half, 320, 4), - BENCHMARK_TYPE(rocprim::half, 512, 4), - - BENCHMARK_TYPE(long long, 64, 3), - BENCHMARK_TYPE(long long, 512, 3), - - BENCHMARK_TYPE(long long, 64, 4), - BENCHMARK_TYPE(long long, 128, 4), - BENCHMARK_TYPE(long long, 192, 4), - BENCHMARK_TYPE(long long, 256, 4), - BENCHMARK_TYPE(long long, 320, 4), - BENCHMARK_TYPE(long long, 512, 4), - - BENCHMARK_TYPE(custom_int_type, 64, 3), - BENCHMARK_TYPE(custom_int_type, 512, 3), - - BENCHMARK_TYPE(custom_int_type, 64, 4), - BENCHMARK_TYPE(custom_int_type, 128, 4), - BENCHMARK_TYPE(custom_int_type, 192, 4), - BENCHMARK_TYPE(custom_int_type, 256, 4), - BENCHMARK_TYPE(custom_int_type, 320, 4), - BENCHMARK_TYPE(custom_int_type, 512, 4), - - BENCHMARK_TYPE(rocprim::int128_t, 64, 3), - BENCHMARK_TYPE(rocprim::int128_t, 512, 3), - - BENCHMARK_TYPE(rocprim::int128_t, 64, 4), - BENCHMARK_TYPE(rocprim::int128_t, 128, 4), - BENCHMARK_TYPE(rocprim::int128_t, 192, 4), - BENCHMARK_TYPE(rocprim::int128_t, 256, 4), - BENCHMARK_TYPE(rocprim::int128_t, 320, 4), - BENCHMARK_TYPE(rocprim::int128_t, 512, 4), - - BENCHMARK_TYPE(rocprim::uint128_t, 64, 3), - BENCHMARK_TYPE(rocprim::uint128_t, 512, 3), - - BENCHMARK_TYPE(rocprim::uint128_t, 64, 4), - BENCHMARK_TYPE(rocprim::uint128_t, 128, 4), - BENCHMARK_TYPE(rocprim::uint128_t, 192, 4), - BENCHMARK_TYPE(rocprim::uint128_t, 256, 4), - BENCHMARK_TYPE(rocprim::uint128_t, 320, 4), - BENCHMARK_TYPE(rocprim::uint128_t, 512, 4)}; - - benchmarks.insert(benchmarks.end(), bs.begin(), bs.end()); + BENCHMARK_TYPE(int, 64, 3) + BENCHMARK_TYPE(int, 512, 3) + + BENCHMARK_TYPE(int, 64, 4) + BENCHMARK_TYPE(int, 128, 4) + BENCHMARK_TYPE(int, 192, 4) + BENCHMARK_TYPE(int, 256, 4) + BENCHMARK_TYPE(int, 320, 4) + BENCHMARK_TYPE(int, 512, 4) + + BENCHMARK_TYPE(int8_t, 64, 3) + BENCHMARK_TYPE(int8_t, 512, 3) + + BENCHMARK_TYPE(int8_t, 64, 4) + BENCHMARK_TYPE(int8_t, 128, 4) + BENCHMARK_TYPE(int8_t, 192, 4) + BENCHMARK_TYPE(int8_t, 256, 4) + BENCHMARK_TYPE(int8_t, 320, 4) + BENCHMARK_TYPE(int8_t, 512, 4) + + BENCHMARK_TYPE(uint8_t, 64, 3) + BENCHMARK_TYPE(uint8_t, 512, 3) + + BENCHMARK_TYPE(uint8_t, 64, 4) + BENCHMARK_TYPE(uint8_t, 128, 4) + BENCHMARK_TYPE(uint8_t, 192, 4) + BENCHMARK_TYPE(uint8_t, 256, 4) + BENCHMARK_TYPE(uint8_t, 320, 4) + BENCHMARK_TYPE(uint8_t, 512, 4) + + BENCHMARK_TYPE(rocprim::half, 64, 3) + BENCHMARK_TYPE(rocprim::half, 512, 3) + + BENCHMARK_TYPE(rocprim::half, 64, 4) + BENCHMARK_TYPE(rocprim::half, 128, 4) + BENCHMARK_TYPE(rocprim::half, 192, 4) + BENCHMARK_TYPE(rocprim::half, 256, 4) + BENCHMARK_TYPE(rocprim::half, 320, 4) + BENCHMARK_TYPE(rocprim::half, 512, 4) + + BENCHMARK_TYPE(long long, 64, 3) + BENCHMARK_TYPE(long long, 512, 3) + + BENCHMARK_TYPE(long long, 64, 4) + BENCHMARK_TYPE(long long, 128, 4) + BENCHMARK_TYPE(long long, 192, 4) + BENCHMARK_TYPE(long long, 256, 4) + BENCHMARK_TYPE(long long, 320, 4) + BENCHMARK_TYPE(long long, 512, 4) + + BENCHMARK_TYPE(custom_int_type, 64, 3) + BENCHMARK_TYPE(custom_int_type, 512, 3) + + BENCHMARK_TYPE(custom_int_type, 64, 4) + BENCHMARK_TYPE(custom_int_type, 128, 4) + BENCHMARK_TYPE(custom_int_type, 192, 4) + BENCHMARK_TYPE(custom_int_type, 256, 4) + BENCHMARK_TYPE(custom_int_type, 320, 4) + BENCHMARK_TYPE(custom_int_type, 512, 4) + + BENCHMARK_TYPE(rocprim::int128_t, 64, 3) + BENCHMARK_TYPE(rocprim::int128_t, 512, 3) + + BENCHMARK_TYPE(rocprim::int128_t, 64, 4) + BENCHMARK_TYPE(rocprim::int128_t, 128, 4) + BENCHMARK_TYPE(rocprim::int128_t, 192, 4) + BENCHMARK_TYPE(rocprim::int128_t, 256, 4) + BENCHMARK_TYPE(rocprim::int128_t, 320, 4) + BENCHMARK_TYPE(rocprim::int128_t, 512, 4) + + BENCHMARK_TYPE(rocprim::uint128_t, 64, 3) + BENCHMARK_TYPE(rocprim::uint128_t, 512, 3) + + BENCHMARK_TYPE(rocprim::uint128_t, 64, 4) + BENCHMARK_TYPE(rocprim::uint128_t, 128, 4) + BENCHMARK_TYPE(rocprim::uint128_t, 192, 4) + BENCHMARK_TYPE(rocprim::uint128_t, 256, 4) + BENCHMARK_TYPE(rocprim::uint128_t, 320, 4) + BENCHMARK_TYPE(rocprim::uint128_t, 512, 4) } int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); - - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks; - add_benchmarks(benchmark_kinds::sort_keys, "keys", benchmarks, bytes, seed, stream); - add_benchmarks(benchmark_kinds::sort_pairs, "pairs", benchmarks, bytes, seed, stream); - - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } + benchmark_utils::executor executor(argc, argv, 512 * benchmark_utils::MiB, 1, 0); - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } + add_benchmarks("keys", executor); + add_benchmarks("pairs", executor); - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_block_reduce.cpp b/benchmark/benchmark_block_reduce.cpp index 497c3cced..27678d0d7 100644 --- a/benchmark/benchmark_block_reduce.cpp +++ b/benchmark/benchmark_block_reduce.cpp @@ -21,13 +21,9 @@ // SOFTWARE. #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" #include "../common/utils_custom_type.hpp" - -// Google Benchmark -#include +#include "../common/utils_device_ptr.hpp" // HIP API #include @@ -42,10 +38,6 @@ #include #include -#ifndef DEFAULT_N -const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; -#endif - template -void run_benchmark(benchmark::State& state, hipStream_t stream, size_t bytes) +void run_benchmark(benchmark_utils::state&& state) { + const auto& bytes = state.bytes; + const auto& stream = state.stream; + // Calculate the number of elements N size_t N = bytes / sizeof(T); // Make sure size is a multiple of BlockSize @@ -104,172 +99,96 @@ void run_benchmark(benchmark::State& state, hipStream_t stream, size_t bytes) const auto size = items_per_block * ((N + items_per_block - 1) / items_per_block); // Allocate and fill memory std::vector input(size, T(1)); - T* d_input; - T* d_output; - HIP_CHECK(hipMalloc(reinterpret_cast(&d_input), size * sizeof(T))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_output), size * sizeof(T))); - HIP_CHECK(hipMemcpy(d_input, input.data(), size * sizeof(T), hipMemcpyHostToDevice)); + common::device_ptr d_input(input); + common::device_ptr d_output(size); HIP_CHECK(hipDeviceSynchronize()); - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - hipLaunchKernelGGL(HIP_KERNEL_NAME(kernel), - dim3(size / items_per_block), - dim3(BlockSize), - 0, - stream, - d_input, - d_output); - HIP_CHECK(hipGetLastError()); - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * size * sizeof(T) * Trials); - state.SetItemsProcessed(state.iterations() * size * Trials); + state.run( + [&] + { + kernel + <<>>(d_input.get(), + d_output.get()); + HIP_CHECK(hipGetLastError()); + }); - HIP_CHECK(hipFree(d_input)); - HIP_CHECK(hipFree(d_output)); + state.set_throughput(size * Trials, sizeof(T)); } -// IPT - items per thread -#define CREATE_BENCHMARK(T, BS, IPT) \ - benchmark::RegisterBenchmark(bench_naming::format_name("{lvl:block,algo:reduce,key_type:" #T \ - ",cfg:{bs:" #BS ",ipt:" #IPT ",method:" \ - + method_name + "}}") \ - .c_str(), \ - run_benchmark, \ - stream, \ - bytes) - -#define BENCHMARK_TYPE(type, block) \ - CREATE_BENCHMARK(type, block, 1), CREATE_BENCHMARK(type, block, 2), \ - CREATE_BENCHMARK(type, block, 3), CREATE_BENCHMARK(type, block, 4), \ - CREATE_BENCHMARK(type, block, 8), CREATE_BENCHMARK(type, block, 11), \ - CREATE_BENCHMARK(type, block, 16) +#define CREATE_BENCHMARK(T, BS, IPT) \ + executor.queue_fn(bench_naming::format_name("{lvl:block,algo:reduce,key_type:" #T \ + ",cfg:{bs:" #BS ",ipt:" #IPT ",method:" \ + + name + "}}") \ + .c_str(), \ + run_benchmark); + +#define BENCHMARK_TYPE(type, block) \ + CREATE_BENCHMARK(type, block, 1) \ + CREATE_BENCHMARK(type, block, 2) \ + CREATE_BENCHMARK(type, block, 3) \ + CREATE_BENCHMARK(type, block, 4) \ + CREATE_BENCHMARK(type, block, 8) \ + CREATE_BENCHMARK(type, block, 11) \ + CREATE_BENCHMARK(type, block, 16) template -void add_benchmarks(std::vector& benchmarks, - const std::string& method_name, - hipStream_t stream, - size_t bytes) +void add_benchmarks(const std::string& name, benchmark_utils::executor& executor) { using custom_float2 = common::custom_type; using custom_double2 = common::custom_type; - std::vector new_benchmarks - = {// When block size is less than or equal to warp size - BENCHMARK_TYPE(int, 64), - BENCHMARK_TYPE(float, 64), - BENCHMARK_TYPE(double, 64), - BENCHMARK_TYPE(int8_t, 64), - BENCHMARK_TYPE(uint8_t, 64), - BENCHMARK_TYPE(rocprim::half, 64), - BENCHMARK_TYPE(rocprim::int128_t, 64), - BENCHMARK_TYPE(rocprim::uint128_t, 64), - - BENCHMARK_TYPE(int, 256), - BENCHMARK_TYPE(float, 256), - BENCHMARK_TYPE(double, 256), - BENCHMARK_TYPE(int8_t, 256), - BENCHMARK_TYPE(uint8_t, 256), - BENCHMARK_TYPE(rocprim::half, 256), - BENCHMARK_TYPE(rocprim::int128_t, 256), - BENCHMARK_TYPE(rocprim::uint128_t, 256), - - CREATE_BENCHMARK(custom_float2, 256, 1), - CREATE_BENCHMARK(custom_float2, 256, 4), - CREATE_BENCHMARK(custom_float2, 256, 8), - - CREATE_BENCHMARK(float2, 256, 1), - CREATE_BENCHMARK(float2, 256, 4), - CREATE_BENCHMARK(float2, 256, 8), - - CREATE_BENCHMARK(custom_double2, 256, 1), - CREATE_BENCHMARK(custom_double2, 256, 4), - CREATE_BENCHMARK(custom_double2, 256, 8), - - CREATE_BENCHMARK(double2, 256, 1), - CREATE_BENCHMARK(double2, 256, 4), - CREATE_BENCHMARK(double2, 256, 8), - - CREATE_BENCHMARK(float4, 256, 1), - CREATE_BENCHMARK(float4, 256, 4), - CREATE_BENCHMARK(float4, 256, 8)}; - benchmarks.insert(benchmarks.end(), new_benchmarks.begin(), new_benchmarks.end()); + // When block size is less than or equal to warp size + BENCHMARK_TYPE(int, 64) + BENCHMARK_TYPE(float, 64) + BENCHMARK_TYPE(double, 64) + BENCHMARK_TYPE(int8_t, 64) + BENCHMARK_TYPE(uint8_t, 64) + BENCHMARK_TYPE(rocprim::half, 64) + BENCHMARK_TYPE(rocprim::int128_t, 64) + BENCHMARK_TYPE(rocprim::uint128_t, 64) + + BENCHMARK_TYPE(int, 256) + BENCHMARK_TYPE(float, 256) + BENCHMARK_TYPE(double, 256) + BENCHMARK_TYPE(int8_t, 256) + BENCHMARK_TYPE(uint8_t, 256) + BENCHMARK_TYPE(rocprim::half, 256) + BENCHMARK_TYPE(rocprim::int128_t, 256) + BENCHMARK_TYPE(rocprim::uint128_t, 256) + + CREATE_BENCHMARK(custom_float2, 256, 1) + CREATE_BENCHMARK(custom_float2, 256, 4) + CREATE_BENCHMARK(custom_float2, 256, 8) + + CREATE_BENCHMARK(float2, 256, 1) + CREATE_BENCHMARK(float2, 256, 4) + CREATE_BENCHMARK(float2, 256, 8) + + CREATE_BENCHMARK(custom_double2, 256, 1) + CREATE_BENCHMARK(custom_double2, 256, 4) + CREATE_BENCHMARK(custom_double2, 256, 8) + + CREATE_BENCHMARK(double2, 256, 1) + CREATE_BENCHMARK(double2, 256, 4) + CREATE_BENCHMARK(double2, 256, 8) + + CREATE_BENCHMARK(float4, 256, 1) + CREATE_BENCHMARK(float4, 256, 4) + CREATE_BENCHMARK(float4, 256, 8) } int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); + benchmark_utils::executor executor(argc, argv, 128 * benchmark_utils::MiB, 1, 0); - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - - // Add benchmarks - std::vector benchmarks; - // using_warp_scan using reduce_uwr_t = reduce; - add_benchmarks(benchmarks, "using_warp_reduce", stream, bytes); - // reduce then scan - using reduce_rr_t = reduce; - add_benchmarks(benchmarks, "raking_reduce", stream, bytes); - // reduce commutative only - using reduce_rrco_t = reduce; - add_benchmarks(benchmarks, "raking_reduce_commutative_only", stream, bytes); + add_benchmarks("using_warp_reduce", executor); - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } + using reduce_rr_t = reduce; + add_benchmarks("raking_reduce", executor); - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } + using reduce_rrco_t = reduce; + add_benchmarks("raking_reduce_commutative_only", executor); - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_block_run_length_decode.cpp b/benchmark/benchmark_block_run_length_decode.cpp index 81a713aa5..3b20d0f80 100644 --- a/benchmark/benchmark_block_run_length_decode.cpp +++ b/benchmark/benchmark_block_run_length_decode.cpp @@ -21,16 +21,14 @@ // SOFTWARE. #include "benchmark_utils.hpp" -#include "cmdparser.hpp" #include "../common/utils_data_generation.hpp" - -#include +#include "../common/utils_device_ptr.hpp" #include #include #include -#include +#include #include #include @@ -39,10 +37,6 @@ #include #include -#ifndef DEFAULT_N -const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; -#endif - template -void run_benchmark(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) +void run_benchmark(benchmark_utils::state&& state) { + const auto& bytes = state.bytes; + const auto& seed = state.seed; + const auto& stream = state.stream; + // Calculate the number of elements N size_t N = bytes / sizeof(ItemT); constexpr auto runs_per_block = BlockSize * RunsPerThread; @@ -133,147 +128,76 @@ void run_benchmark(benchmark::State& state, } const OffsetT output_length = run_offsets.back(); - ItemT* d_run_items{}; - HIP_CHECK(hipMalloc(&d_run_items, run_items.size() * sizeof(ItemT))); - HIP_CHECK(hipMemcpy(d_run_items, - run_items.data(), - run_items.size() * sizeof(ItemT), - hipMemcpyHostToDevice)); + common::device_ptr d_run_items(run_items); - OffsetT* d_run_offsets{}; - HIP_CHECK(hipMalloc(&d_run_offsets, run_offsets.size() * sizeof(OffsetT))); - HIP_CHECK(hipMemcpy(d_run_offsets, - run_offsets.data(), - run_offsets.size() * sizeof(OffsetT), - hipMemcpyHostToDevice)); + common::device_ptr d_run_offsets(run_offsets); - ItemT* d_output{}; - HIP_CHECK(hipMalloc(&d_output, output_length * sizeof(ItemT))); + common::device_ptr d_output(output_length); - for(auto _ : state) - { - auto start = std::chrono::steady_clock::now(); - hipLaunchKernelGGL(HIP_KERNEL_NAME(block_run_length_decode_kernel), - dim3(num_runs / runs_per_block), - dim3(BlockSize), - 0, - stream, - d_run_items, - d_run_offsets, - d_output); - HIP_CHECK(hipPeekAtLastError()); - HIP_CHECK(hipDeviceSynchronize()); - - auto end = std::chrono::steady_clock::now(); - auto elapsed_seconds - = std::chrono::duration_cast>(end - start); - - state.SetIterationTime(elapsed_seconds.count()); - } - state.SetBytesProcessed(state.iterations() * output_length * sizeof(ItemT) * Trials); - state.SetItemsProcessed(state.iterations() * output_length * Trials); - - HIP_CHECK(hipFree(d_run_items)); - HIP_CHECK(hipFree(d_run_offsets)); - HIP_CHECK(hipFree(d_output)); + state.run( + [&] + { + block_run_length_decode_kernel + <<>>( + d_run_items.get(), + d_run_offsets.get(), + d_output.get()); + HIP_CHECK(hipPeekAtLastError()); + HIP_CHECK(hipDeviceSynchronize()); + }); + + state.set_throughput(output_length * Trials, sizeof(ItemT)); } #define CREATE_BENCHMARK(IT, OT, MINRL, MAXRL, BS, RPT, DIPT) \ - benchmark::RegisterBenchmark( \ + executor.queue_fn( \ bench_naming::format_name("{lvl:block,algo:run_length_decode" \ ",item_type:" #IT ",offset_type:" #OT ",min_run_length:" #MINRL \ ",max_run_length:" #MAXRL ",cfg:{block_size:" #BS \ ",run_per_thread:" #RPT ",decoded_items_per_thread:" #DIPT "}}") \ .c_str(), \ - &run_benchmark, \ - bytes, \ - seed, \ - stream) + &run_benchmark); int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); - - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks{ - CREATE_BENCHMARK(int, int, 1, 5, 128, 2, 4), - CREATE_BENCHMARK(int, int, 1, 10, 128, 2, 4), - CREATE_BENCHMARK(int, int, 1, 50, 128, 2, 4), - CREATE_BENCHMARK(int, int, 1, 100, 128, 2, 4), - CREATE_BENCHMARK(int, int, 1, 500, 128, 2, 4), - CREATE_BENCHMARK(int, int, 1, 1000, 128, 2, 4), - CREATE_BENCHMARK(int, int, 1, 5000, 128, 2, 4), - - CREATE_BENCHMARK(double, long long, 1, 5, 128, 2, 4), - CREATE_BENCHMARK(double, long long, 1, 10, 128, 2, 4), - CREATE_BENCHMARK(double, long long, 1, 50, 128, 2, 4), - CREATE_BENCHMARK(double, long long, 1, 100, 128, 2, 4), - CREATE_BENCHMARK(double, long long, 1, 500, 128, 2, 4), - CREATE_BENCHMARK(double, long long, 1, 1000, 128, 2, 4), - CREATE_BENCHMARK(double, long long, 1, 5000, 128, 2, 4), - - CREATE_BENCHMARK(rocprim::int128_t, rocprim::int128_t, 1, 5, 128, 2, 4), - CREATE_BENCHMARK(rocprim::int128_t, rocprim::int128_t, 1, 10, 128, 2, 4), - CREATE_BENCHMARK(rocprim::int128_t, rocprim::int128_t, 1, 50, 128, 2, 4), - CREATE_BENCHMARK(rocprim::int128_t, rocprim::int128_t, 1, 100, 128, 2, 4), - CREATE_BENCHMARK(rocprim::int128_t, rocprim::int128_t, 1, 500, 128, 2, 4), - CREATE_BENCHMARK(rocprim::int128_t, rocprim::int128_t, 1, 1000, 128, 2, 4), - CREATE_BENCHMARK(rocprim::int128_t, rocprim::int128_t, 1, 5000, 128, 2, 4), - - CREATE_BENCHMARK(rocprim::uint128_t, rocprim::uint128_t, 1, 5, 128, 2, 4), - CREATE_BENCHMARK(rocprim::uint128_t, rocprim::uint128_t, 1, 10, 128, 2, 4), - CREATE_BENCHMARK(rocprim::uint128_t, rocprim::uint128_t, 1, 50, 128, 2, 4), - CREATE_BENCHMARK(rocprim::uint128_t, rocprim::uint128_t, 1, 100, 128, 2, 4), - CREATE_BENCHMARK(rocprim::uint128_t, rocprim::uint128_t, 1, 500, 128, 2, 4), - CREATE_BENCHMARK(rocprim::uint128_t, rocprim::uint128_t, 1, 1000, 128, 2, 4), - CREATE_BENCHMARK(rocprim::uint128_t, rocprim::uint128_t, 1, 5000, 128, 2, 4)}; - - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } - - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + benchmark_utils::executor executor(argc, argv, 128 * benchmark_utils::MiB, 1, 0); + + CREATE_BENCHMARK(int, int, 1, 5, 128, 2, 4) + CREATE_BENCHMARK(int, int, 1, 10, 128, 2, 4) + CREATE_BENCHMARK(int, int, 1, 50, 128, 2, 4) + CREATE_BENCHMARK(int, int, 1, 100, 128, 2, 4) + CREATE_BENCHMARK(int, int, 1, 500, 128, 2, 4) + CREATE_BENCHMARK(int, int, 1, 1000, 128, 2, 4) + CREATE_BENCHMARK(int, int, 1, 5000, 128, 2, 4) + + CREATE_BENCHMARK(double, long long, 1, 5, 128, 2, 4) + CREATE_BENCHMARK(double, long long, 1, 10, 128, 2, 4) + CREATE_BENCHMARK(double, long long, 1, 50, 128, 2, 4) + CREATE_BENCHMARK(double, long long, 1, 100, 128, 2, 4) + CREATE_BENCHMARK(double, long long, 1, 500, 128, 2, 4) + CREATE_BENCHMARK(double, long long, 1, 1000, 128, 2, 4) + CREATE_BENCHMARK(double, long long, 1, 5000, 128, 2, 4) + + CREATE_BENCHMARK(rocprim::int128_t, rocprim::int128_t, 1, 5, 128, 2, 4) + CREATE_BENCHMARK(rocprim::int128_t, rocprim::int128_t, 1, 10, 128, 2, 4) + CREATE_BENCHMARK(rocprim::int128_t, rocprim::int128_t, 1, 50, 128, 2, 4) + CREATE_BENCHMARK(rocprim::int128_t, rocprim::int128_t, 1, 100, 128, 2, 4) + CREATE_BENCHMARK(rocprim::int128_t, rocprim::int128_t, 1, 500, 128, 2, 4) + CREATE_BENCHMARK(rocprim::int128_t, rocprim::int128_t, 1, 1000, 128, 2, 4) + CREATE_BENCHMARK(rocprim::int128_t, rocprim::int128_t, 1, 5000, 128, 2, 4) + + CREATE_BENCHMARK(rocprim::uint128_t, rocprim::uint128_t, 1, 5, 128, 2, 4) + CREATE_BENCHMARK(rocprim::uint128_t, rocprim::uint128_t, 1, 10, 128, 2, 4) + CREATE_BENCHMARK(rocprim::uint128_t, rocprim::uint128_t, 1, 50, 128, 2, 4) + CREATE_BENCHMARK(rocprim::uint128_t, rocprim::uint128_t, 1, 100, 128, 2, 4) + CREATE_BENCHMARK(rocprim::uint128_t, rocprim::uint128_t, 1, 500, 128, 2, 4) + CREATE_BENCHMARK(rocprim::uint128_t, rocprim::uint128_t, 1, 1000, 128, 2, 4) + CREATE_BENCHMARK(rocprim::uint128_t, rocprim::uint128_t, 1, 5000, 128, 2, 4) + + executor.run(); } diff --git a/benchmark/benchmark_block_scan.cpp b/benchmark/benchmark_block_scan.cpp index 0e4e3b7f0..49933c9e0 100644 --- a/benchmark/benchmark_block_scan.cpp +++ b/benchmark/benchmark_block_scan.cpp @@ -21,13 +21,9 @@ // SOFTWARE. #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" #include "../common/utils_custom_type.hpp" - -// Google Benchmark -#include +#include "../common/utils_device_ptr.hpp" // HIP API #include @@ -43,10 +39,6 @@ #include #include -#ifndef DEFAULT_N -const size_t DEFAULT_BYTES = 1024 * 1024 * 128 * 4; -#endif - template -void run_benchmark(benchmark::State& state, hipStream_t stream, size_t bytes) +void run_benchmark(benchmark_utils::state&& state) { + const auto& bytes = state.bytes; + const auto& stream = state.stream; + // Calculate the number of elements N size_t N = bytes / sizeof(T); // Make sure size is a multiple of BlockSize @@ -137,197 +132,106 @@ void run_benchmark(benchmark::State& state, hipStream_t stream, size_t bytes) const auto size = items_per_block * ((N + items_per_block - 1) / items_per_block); // Allocate and fill memory std::vector input(size, T(1)); - T* d_input; - T* d_output; - HIP_CHECK(hipMalloc(reinterpret_cast(&d_input), size * sizeof(T))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_output), size * sizeof(T))); - HIP_CHECK(hipMemcpy(d_input, input.data(), size * sizeof(T), hipMemcpyHostToDevice)); + common::device_ptr d_input(input); + common::device_ptr d_output(size); HIP_CHECK(hipDeviceSynchronize()); - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - hipLaunchKernelGGL(HIP_KERNEL_NAME(kernel), - dim3(size / items_per_block), - dim3(BlockSize), - 0, - stream, - d_input, - d_output); - HIP_CHECK(hipGetLastError()); - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * size * sizeof(T) * Trials); - state.SetItemsProcessed(state.iterations() * size * Trials); + state.run( + [&] + { + kernel + <<>>(d_input.get(), + d_output.get()); + HIP_CHECK(hipGetLastError()); + }); - HIP_CHECK(hipFree(d_input)); - HIP_CHECK(hipFree(d_output)); + state.set_throughput(size * Trials, sizeof(T)); } -// IPT - items per thread -#define CREATE_BENCHMARK(T, BS, IPT) \ - benchmark::RegisterBenchmark( \ - bench_naming::format_name("{lvl:block,algo:scan,subalgo:" + algorithm_name \ - + ",key_type:" #T ",cfg:{bs:" #BS ",ipt:" #IPT ",method:" \ - + method_name + "}}") \ - .c_str(), \ - run_benchmark, \ - stream, \ - bytes) - -#define BENCHMARK_TYPE(type, block) \ - CREATE_BENCHMARK(type, block, 1), CREATE_BENCHMARK(type, block, 2), \ - CREATE_BENCHMARK(type, block, 3), CREATE_BENCHMARK(type, block, 4), \ - CREATE_BENCHMARK(type, block, 8), CREATE_BENCHMARK(type, block, 11), \ - CREATE_BENCHMARK(type, block, 16) +#define CREATE_BENCHMARK(T, BS, IPT) \ + executor.queue_fn(bench_naming::format_name("{lvl:block,algo:scan,subalgo:" + algorithm_name \ + + ",key_type:" #T ",cfg:{bs:" #BS ",ipt:" #IPT \ + ",method:" \ + + method_name + "}}") \ + .c_str(), \ + run_benchmark); + +#define BENCHMARK_TYPE(type, block) \ + CREATE_BENCHMARK(type, block, 1) \ + CREATE_BENCHMARK(type, block, 2) \ + CREATE_BENCHMARK(type, block, 3) \ + CREATE_BENCHMARK(type, block, 4) \ + CREATE_BENCHMARK(type, block, 8) \ + CREATE_BENCHMARK(type, block, 11) \ + CREATE_BENCHMARK(type, block, 16) template -void add_benchmarks(std::vector& benchmarks, - const std::string& method_name, - const std::string& algorithm_name, - hipStream_t stream, - size_t bytes) +void add_benchmarks(const std::string& method_name, + const std::string& algorithm_name, + benchmark_utils::executor& executor) { using custom_float2 = common::custom_type; using custom_double2 = common::custom_type; - std::vector new_benchmarks - = {// When block size is less than or equal to warp size - BENCHMARK_TYPE(int, 64), - BENCHMARK_TYPE(float, 64), - BENCHMARK_TYPE(double, 64), - BENCHMARK_TYPE(int8_t, 64), - BENCHMARK_TYPE(uint8_t, 64), - BENCHMARK_TYPE(rocprim::half, 64), - - BENCHMARK_TYPE(int, 256), - BENCHMARK_TYPE(float, 256), - BENCHMARK_TYPE(double, 256), - BENCHMARK_TYPE(int8_t, 256), - BENCHMARK_TYPE(uint8_t, 256), - BENCHMARK_TYPE(rocprim::half, 256), - - CREATE_BENCHMARK(custom_float2, 256, 1), - CREATE_BENCHMARK(custom_float2, 256, 4), - CREATE_BENCHMARK(custom_float2, 256, 8), - - CREATE_BENCHMARK(float2, 256, 1), - CREATE_BENCHMARK(float2, 256, 4), - CREATE_BENCHMARK(float2, 256, 8), - - CREATE_BENCHMARK(custom_double2, 256, 1), - CREATE_BENCHMARK(custom_double2, 256, 4), - CREATE_BENCHMARK(custom_double2, 256, 8), - - CREATE_BENCHMARK(double2, 256, 1), - CREATE_BENCHMARK(double2, 256, 4), - CREATE_BENCHMARK(double2, 256, 8), - - CREATE_BENCHMARK(float4, 256, 1), - CREATE_BENCHMARK(float4, 256, 4), - CREATE_BENCHMARK(float4, 256, 8), - - CREATE_BENCHMARK(rocprim::int128_t, 256, 1), - CREATE_BENCHMARK(rocprim::int128_t, 256, 4), - CREATE_BENCHMARK(rocprim::int128_t, 256, 8), - - CREATE_BENCHMARK(rocprim::uint128_t, 256, 1), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 4), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 8)}; - benchmarks.insert(benchmarks.end(), new_benchmarks.begin(), new_benchmarks.end()); + // When block size is less than or equal to warp size + BENCHMARK_TYPE(int, 64) + BENCHMARK_TYPE(float, 64) + BENCHMARK_TYPE(double, 64) + BENCHMARK_TYPE(int8_t, 64) + BENCHMARK_TYPE(uint8_t, 64) + BENCHMARK_TYPE(rocprim::half, 64) + + BENCHMARK_TYPE(int, 256) + BENCHMARK_TYPE(float, 256) + BENCHMARK_TYPE(double, 256) + BENCHMARK_TYPE(int8_t, 256) + BENCHMARK_TYPE(uint8_t, 256) + BENCHMARK_TYPE(rocprim::half, 256) + + CREATE_BENCHMARK(custom_float2, 256, 1) + CREATE_BENCHMARK(custom_float2, 256, 4) + CREATE_BENCHMARK(custom_float2, 256, 8) + + CREATE_BENCHMARK(float2, 256, 1) + CREATE_BENCHMARK(float2, 256, 4) + CREATE_BENCHMARK(float2, 256, 8) + + CREATE_BENCHMARK(custom_double2, 256, 1) + CREATE_BENCHMARK(custom_double2, 256, 4) + CREATE_BENCHMARK(custom_double2, 256, 8) + + CREATE_BENCHMARK(double2, 256, 1) + CREATE_BENCHMARK(double2, 256, 4) + CREATE_BENCHMARK(double2, 256, 8) + + CREATE_BENCHMARK(float4, 256, 1) + CREATE_BENCHMARK(float4, 256, 4) + CREATE_BENCHMARK(float4, 256, 8) + + CREATE_BENCHMARK(rocprim::int128_t, 256, 1) + CREATE_BENCHMARK(rocprim::int128_t, 256, 4) + CREATE_BENCHMARK(rocprim::int128_t, 256, 8) + + CREATE_BENCHMARK(rocprim::uint128_t, 256, 1) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 4) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 8) } int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - - // Add benchmarks - std::vector benchmarks; - // inclusive_scan using_warp_scan + benchmark_utils::executor executor(argc, argv, 512 * benchmark_utils::MiB, 1, 0); + using inclusive_scan_uws_t = inclusive_scan; - add_benchmarks(benchmarks, - "inclusive_scan", - "using_warp_scan", - stream, - bytes); - // exclusive_scan using_warp_scan + add_benchmarks("inclusive_scan", "using_warp_scan", executor); + using exclusive_scan_uws_t = exclusive_scan; - add_benchmarks(benchmarks, - "exclusive_scan", - "using_warp_scan", - stream, - bytes); - // inclusive_scan reduce then scan + add_benchmarks("exclusive_scan", "using_warp_scan", executor); + using inclusive_scan_rts_t = inclusive_scan; - add_benchmarks(benchmarks, - "inclusive_scan", - "reduce_then_scan", - stream, - bytes); - // exclusive_scan reduce then scan - using exclusive_scan_rts_t = exclusive_scan; - add_benchmarks(benchmarks, - "exclusive_scan", - "reduce_then_scan", - stream, - bytes); - - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } + add_benchmarks("inclusive_scan", "reduce_then_scan", executor); - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } + using exclusive_scan_rts_t = exclusive_scan; + add_benchmarks("exclusive_scan", "reduce_then_scan", executor); - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_block_sort.cpp b/benchmark/benchmark_block_sort.cpp index 83c42c15f..ef0fab3c8 100644 --- a/benchmark/benchmark_block_sort.cpp +++ b/benchmark/benchmark_block_sort.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -21,13 +21,6 @@ // SOFTWARE. #include "benchmark_block_sort.parallel.hpp" -#include "benchmark_utils.hpp" - -// CmdParser -#include "cmdparser.hpp" - -// Google Benchmark -#include // HIP API #include @@ -45,38 +38,16 @@ #include #endif -#ifndef DEFAULT_N -const size_t DEFAULT_BYTES = 1024 * 1024 * 128 * 4; -#endif +#define CREATE_BENCHMARK_IPT_ALG(K, V, BS, IPT, ALG) \ + benchmark_utils::executor::queue_sorted_instance< \ + block_sort_benchmark>(); \ + benchmark_utils::executor::queue_sorted_instance< \ + block_sort_benchmark>(); -#define CREATE_BENCHMARK_IPT(K, V, BS, IPT) \ - config_autotune_register::create< \ - block_sort_benchmark>(); \ - config_autotune_register::create< \ - block_sort_benchmark>(); \ - config_autotune_register::create< \ - block_sort_benchmark>(); \ - config_autotune_register::create< \ - block_sort_benchmark>(); \ - config_autotune_register::create< \ - block_sort_benchmark>(); \ - config_autotune_register::create< \ - block_sort_benchmark>(); +#define CREATE_BENCHMARK_IPT(K, V, BS, IPT) \ + CREATE_BENCHMARK_IPT_ALG(K, V, BS, IPT, rocprim::block_sort_algorithm::merge_sort) \ + CREATE_BENCHMARK_IPT_ALG(K, V, BS, IPT, rocprim::block_sort_algorithm::stable_merge_sort) \ + CREATE_BENCHMARK_IPT_ALG(K, V, BS, IPT, rocprim::block_sort_algorithm::bitonic_sort) #define CREATE_BENCHMARK(K, V, BS) \ CREATE_BENCHMARK_IPT(K, V, BS, 1) \ @@ -84,35 +55,9 @@ const size_t DEFAULT_BYTES = 1024 * 1024 * 128 * 4; int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); - - // HIP - const hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); + benchmark_utils::executor executor(argc, argv, 512 * benchmark_utils::MiB, 10, 0); -// If we are NOT config tuning run a selection of benchmarks -// Block sizes as large as possible ar most relevant -#ifndef BENCHMARK_CONFIG_TUNING + // Block sizes as large as possible are most relevant CREATE_BENCHMARK(float, rocprim::empty_type, 256) CREATE_BENCHMARK(double, rocprim::empty_type, 256) CREATE_BENCHMARK(rocprim::half, rocprim::empty_type, 256) @@ -129,28 +74,6 @@ int main(int argc, char* argv[]) CREATE_BENCHMARK(uint8_t, uint32_t, 512) CREATE_BENCHMARK(int64_t, rocprim::int128_t, 512) CREATE_BENCHMARK(uint64_t, rocprim::uint128_t, 512) -#endif - - std::vector benchmarks = {}; - config_autotune_register::register_benchmark_subset(benchmarks, 0, 1, bytes, seed, stream); - - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_block_sort.parallel.hpp b/benchmark/benchmark_block_sort.parallel.hpp index c00c11a40..a40482788 100644 --- a/benchmark/benchmark_block_sort.parallel.hpp +++ b/benchmark/benchmark_block_sort.parallel.hpp @@ -26,6 +26,7 @@ #include "benchmark_utils.hpp" #include "../common/utils_data_generation.hpp" +#include "../common/utils_device_ptr.hpp" // Google Benchmark #include @@ -39,7 +40,7 @@ #include #include #include -#include +#include #include #include @@ -159,7 +160,7 @@ template -struct block_sort_benchmark : public config_autotune_interface +struct block_sort_benchmark : public benchmark_utils::autotune_interface { private: static constexpr bool with_values = !std::is_same::value; @@ -195,25 +196,14 @@ struct block_sort_benchmark : public config_autotune_interface + ",method:" + std::string(get_block_sort_method_name(block_sort_algorithm)) + "}}"); } - static constexpr unsigned int batch_size = 10; - static constexpr unsigned int warmup_size = 5; - static constexpr bool debug_synchronous = false; - static auto dispatch_block_sort(std::false_type /*stable_sort*/, size_t size, const hipStream_t stream, KeyType* d_input, KeyType* d_output) { - hipLaunchKernelGGL( - HIP_KERNEL_NAME( - sort_kernel), - dim3(size / items_per_block), - dim3(BlockSize), - 0, - stream, - d_input, - d_output); + sort_kernel + <<>>(d_input, d_output); } static auto dispatch_block_sort(std::true_type /*stable_sort*/, @@ -222,24 +212,16 @@ struct block_sort_benchmark : public config_autotune_interface KeyType* d_input, KeyType* d_output) { - hipLaunchKernelGGL(HIP_KERNEL_NAME(stable_sort_kernel), - dim3(size / items_per_block), - dim3(BlockSize), - 0, - stream, - d_input, - d_output); + stable_sort_kernel + <<>>(d_input, d_output); } - void run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + // Calculate the number of elements N size_t N = bytes / sizeof(KeyType); @@ -251,54 +233,21 @@ struct block_sort_benchmark : public config_autotune_interface common::generate_limits::max(), seed.get_0()); - KeyType* d_input; - KeyType* d_output; - HIP_CHECK(hipMalloc(reinterpret_cast(&d_input), size * sizeof(KeyType))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_output), size * sizeof(KeyType))); - HIP_CHECK(hipMemcpy(d_input, input.data(), size * sizeof(KeyType), hipMemcpyHostToDevice)); + common::device_ptr d_input(input); + common::device_ptr d_output(size); HIP_CHECK(hipDeviceSynchronize()); static constexpr auto stable_tag = rocprim::detail::bool_constant{}; - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - // Run - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) - { - dispatch_block_sort(stable_tag, size, stream, d_input, d_output); - } - HIP_CHECK(hipGetLastError()); - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(KeyType)); - state.SetItemsProcessed(state.iterations() * batch_size * size); + state.run( + [&] { dispatch_block_sort(stable_tag, size, stream, d_input.get(), d_output.get()); }); - state.counters["sorted_size"] = benchmark::Counter(BlockSize * ItemsPerThread, - benchmark::Counter::kDefaults, - benchmark::Counter::OneK::kIs1024); + state.set_throughput(size, sizeof(KeyType)); - HIP_CHECK(hipFree(d_input)); - HIP_CHECK(hipFree(d_output)); + state.gbench_state.counters["sorted_size"] + = benchmark::Counter(BlockSize * ItemsPerThread, + benchmark::Counter::kDefaults, + benchmark::Counter::OneK::kIs1024); } }; diff --git a/benchmark/benchmark_config_dispatch.cpp b/benchmark/benchmark_config_dispatch.cpp index e1e6eda0c..ed62ce5bd 100644 --- a/benchmark/benchmark_config_dispatch.cpp +++ b/benchmark/benchmark_config_dispatch.cpp @@ -1,20 +1,13 @@ #include "benchmark_utils.hpp" -#include "cmdparser.hpp" #include -#include - #include #include #include #include -#ifndef DEFAULT_N -const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; -#endif - enum class stream_kind { default_stream, @@ -23,12 +16,13 @@ enum class stream_kind async_stream }; -static void BM_host_target_arch(benchmark::State& state, const stream_kind stream_kind) +template +static void BM_host_target_arch(benchmark_utils::state&& state) { - const hipStream_t stream = [stream_kind]() -> hipStream_t + const hipStream_t stream = []() -> hipStream_t { hipStream_t stream = 0; - switch(stream_kind) + switch(StreamKind) { case stream_kind::default_stream: return stream; case stream_kind::per_thread_stream: return hipStreamPerThread; @@ -39,14 +33,17 @@ static void BM_host_target_arch(benchmark::State& state, const stream_kind strea } }(); - for(auto _ : state) - { - rocprim::detail::target_arch target_arch; - HIP_CHECK(rocprim::detail::host_target_arch(stream, target_arch)); - benchmark::DoNotOptimize(target_arch); - } + state.run( + [&] + { + rocprim::detail::target_arch target_arch; + HIP_CHECK(rocprim::detail::host_target_arch(stream, target_arch)); + benchmark::DoNotOptimize(target_arch); + }); - if(stream_kind != stream_kind::default_stream && stream_kind != stream_kind::per_thread_stream) + state.set_throughput(1, sizeof(char)); + + if(StreamKind != stream_kind::default_stream && StreamKind != stream_kind::per_thread_stream) { HIP_CHECK(hipStreamDestroy(stream)); } @@ -57,69 +54,37 @@ void empty_kernel() {} // An empty kernel launch for baseline -static void BM_kernel_launch(benchmark::State& state) +static void BM_kernel_launch(benchmark_utils::state&& state) { - static constexpr hipStream_t stream = 0; + const auto& stream = state.stream; - for(auto _ : state) - { - hipLaunchKernelGGL(empty_kernel, dim3(1), dim3(1), 0, stream); - HIP_CHECK(hipGetLastError()); - } - HIP_CHECK(hipStreamSynchronize(stream)); + state.run( + [&] + { + empty_kernel<<>>(); + HIP_CHECK(hipGetLastError()); + }); + + state.set_throughput(1, sizeof(char)); } -#define CREATE_BENCHMARK(ST, SK) \ - benchmark::RegisterBenchmark(bench_naming::format_name("{lvl:na" \ - ",algo:" #ST ",cfg:default_config}") \ - .c_str(), \ - &BM_host_target_arch, \ - SK) +#define CREATE_BENCHMARK(ST, SK) \ + executor.queue_fn( \ + bench_naming::format_name("{lvl:na,algo:" #ST ",cfg:default_config}").c_str(), \ + BM_host_target_arch); int main(int argc, char** argv) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", 100, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - - // HIP - - std::vector benchmarks{ - CREATE_BENCHMARK(default_stream, stream_kind::default_stream), - CREATE_BENCHMARK(per_thread_stream, stream_kind::per_thread_stream), - CREATE_BENCHMARK(explicit_stream, stream_kind::explicit_stream), - CREATE_BENCHMARK(async_stream, stream_kind::async_stream), - benchmark::RegisterBenchmark( - bench_naming::format_name("{lvl:na,algo:empty_kernel,cfg:default_config}").c_str(), - BM_kernel_launch)}; - - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } + benchmark_utils::executor executor(argc, argv, 128 * benchmark_utils::MiB, 1, 0, true, 100); - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } + CREATE_BENCHMARK(default_stream, stream_kind::default_stream) + CREATE_BENCHMARK(per_thread_stream, stream_kind::per_thread_stream) + CREATE_BENCHMARK(explicit_stream, stream_kind::explicit_stream) + CREATE_BENCHMARK(async_stream, stream_kind::async_stream) + + executor.queue_fn( + bench_naming::format_name("{lvl:na,algo:empty_kernel,cfg:default_config}").c_str(), + BM_kernel_launch); - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_adjacent_difference.cpp b/benchmark/benchmark_device_adjacent_difference.cpp index 9dc0673a2..4313a1346 100644 --- a/benchmark/benchmark_device_adjacent_difference.cpp +++ b/benchmark/benchmark_device_adjacent_difference.cpp @@ -21,16 +21,12 @@ // SOFTWARE. #include "benchmark_device_adjacent_difference.parallel.hpp" -#include "benchmark_utils.hpp" #ifndef BENCHMARK_CONFIG_TUNING #include "../common/device_adjacent_difference.hpp" #include "../common/utils_custom_type.hpp" #endif -// Google Benchmark -#include - // HIP API #include @@ -39,9 +35,6 @@ #include #endif -// CmdParser -#include "cmdparser.hpp" - #include #include #include @@ -49,111 +42,40 @@ #include #endif -#ifndef DEFAULT_BYTES -constexpr std::size_t DEFAULT_BYTES = 1024LL * 1024LL * 1024LL * 2LL; -#endif - -#define CREATE_BENCHMARK(T, left, in_place) \ - { \ - const device_adjacent_difference_benchmark instance; \ - REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ - } +#define CREATE_BENCHMARK(T, Left, Aliasing) \ + executor.queue_instance(device_adjacent_difference_benchmark()); // clang-format off -#define CREATE_BENCHMARKS(T) \ +#define CREATE_BENCHMARKS(T) \ CREATE_BENCHMARK(T, true, common::api_variant::no_alias) \ - CREATE_BENCHMARK(T, true, common::api_variant::in_place) \ + CREATE_BENCHMARK(T, true, common::api_variant::in_place) \ CREATE_BENCHMARK(T, false, common::api_variant::no_alias) \ CREATE_BENCHMARK(T, false, common::api_variant::in_place) // clang-format on int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "size in bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); -#ifdef BENCHMARK_CONFIG_TUNING - // optionally run an evenly split subset of benchmarks, when making multiple program invocations - parser.set_optional("parallel_instance", - "parallel_instance", - 0, - "parallel instance index"); - parser.set_optional("parallel_instances", - "parallel_instances", - 1, - "total parallel instances"); -#endif - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); + benchmark_utils::executor executor(argc, argv, 2 * benchmark_utils::GiB, 10, 5); - // HIP - const hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); - benchmark::AddCustomContext("seed", seed_type); - - std::vector benchmarks = {}; -#ifdef BENCHMARK_CONFIG_TUNING - const int parallel_instance = parser.get("parallel_instance"); - const int parallel_instances = parser.get("parallel_instances"); - config_autotune_register::register_benchmark_subset(benchmarks, - parallel_instance, - parallel_instances, - size, - seed, - stream); -#else // BENCHMARK_CONFIG_TUNING +#ifndef BENCHMARK_CONFIG_TUNING using custom_float2 = common::custom_type; using custom_double2 = common::custom_type; - // Add benchmarks - CREATE_BENCHMARKS(int) - CREATE_BENCHMARKS(std::int64_t) - CREATE_BENCHMARKS(uint8_t) - CREATE_BENCHMARKS(rocprim::half) + CREATE_BENCHMARKS(int); + CREATE_BENCHMARKS(std::int64_t); - CREATE_BENCHMARKS(float) - CREATE_BENCHMARKS(double) + CREATE_BENCHMARKS(uint8_t); + CREATE_BENCHMARKS(rocprim::half); - CREATE_BENCHMARKS(custom_float2) - CREATE_BENCHMARKS(custom_double2) + CREATE_BENCHMARKS(float); + CREATE_BENCHMARKS(double); - CREATE_BENCHMARKS(rocprim::int128_t) - CREATE_BENCHMARKS(rocprim::uint128_t) -#endif // BENCHMARK_CONFIG_TUNING + CREATE_BENCHMARKS(custom_float2); + CREATE_BENCHMARKS(custom_double2); - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } - - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); + CREATE_BENCHMARKS(rocprim::int128_t); + CREATE_BENCHMARKS(rocprim::uint128_t); +#endif - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_adjacent_difference.parallel.cpp.in b/benchmark/benchmark_device_adjacent_difference.parallel.cpp.in index 804005bcf..039ac6b96 100644 --- a/benchmark/benchmark_device_adjacent_difference.parallel.cpp.in +++ b/benchmark/benchmark_device_adjacent_difference.parallel.cpp.in @@ -30,7 +30,7 @@ #include namespace { - auto benchmarks = config_autotune_register::create_bulk( + auto unused = benchmark_utils::executor::queue_autotune( device_adjacent_difference_benchmark_generator< @DataType@, @BlockSize@, diff --git a/benchmark/benchmark_device_adjacent_difference.parallel.hpp b/benchmark/benchmark_device_adjacent_difference.parallel.hpp index 3f7f2e9ac..051a92700 100644 --- a/benchmark/benchmark_device_adjacent_difference.parallel.hpp +++ b/benchmark/benchmark_device_adjacent_difference.parallel.hpp @@ -26,6 +26,7 @@ #include "benchmark_utils.hpp" #include "../common/device_adjacent_difference.hpp" +#include "../common/utils_device_ptr.hpp" // Google Benchmark #include @@ -65,9 +66,8 @@ template -struct device_adjacent_difference_benchmark : public config_autotune_interface +struct device_adjacent_difference_benchmark : public benchmark_utils::autotune_interface { - std::string name() const override { @@ -78,14 +78,12 @@ struct device_adjacent_difference_benchmark : public config_autotune_interface + std::string(Traits::name()) + ",cfg:" + config_name() + "}"); } - static constexpr unsigned int batch_size = 10; - static constexpr unsigned int warmup_size = 5; - - void run(benchmark::State& state, - const std::size_t bytes, - const managed_seed& seed, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + using output_type = T; static constexpr bool debug_synchronous = false; @@ -96,87 +94,40 @@ struct device_adjacent_difference_benchmark : public config_autotune_interface const std::vector input = get_random_data(size, random_range.first, random_range.second, seed.get_0()); - T* d_input; - output_type* d_output = nullptr; - HIP_CHECK(hipMalloc(&d_input, input.size() * sizeof(input[0]))); - HIP_CHECK(hipMemcpy(d_input, - input.data(), - input.size() * sizeof(input[0]), - hipMemcpyHostToDevice)); + common::device_ptr d_input(input); + common::device_ptr d_output; - if ROCPRIM_IF_CONSTEXPR(Aliasing == common::api_variant::no_alias) + if constexpr(Aliasing == common::api_variant::no_alias) { - HIP_CHECK(hipMalloc(&d_output, size * sizeof(output_type))); + d_output.resize(size); } static constexpr auto left_tag = rocprim::detail::bool_constant{}; static constexpr auto alias_tag = std::integral_constant{}; // Allocate temporary storage - std::size_t temp_storage_size; - void* d_temp_storage = nullptr; + std::size_t temp_storage_size; + common::device_ptr d_temp_storage; const auto launch = [&] { return common::dispatch_adjacent_difference(left_tag, alias_tag, - d_temp_storage, + d_temp_storage.get(), temp_storage_size, - d_input, - d_output, + d_input.get(), + d_output.get(), size, rocprim::plus<>{}, stream, debug_synchronous); }; HIP_CHECK(launch()); - HIP_CHECK(hipMalloc(&d_temp_storage, temp_storage_size)); - - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - HIP_CHECK(launch()); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - // Run - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) - { - HIP_CHECK(launch()); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); + d_temp_storage.resize(temp_storage_size); - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(T)); - state.SetItemsProcessed(state.iterations() * batch_size * size); + state.run([&] { HIP_CHECK(launch()); }); - HIP_CHECK(hipFree(d_input)); - if ROCPRIM_IF_CONSTEXPR(Aliasing == common::api_variant::no_alias) - { - HIP_CHECK(hipFree(d_output)); - } - HIP_CHECK(hipFree(d_temp_storage)); + state.set_throughput(size, sizeof(T)); } }; @@ -195,7 +146,7 @@ struct device_adjacent_difference_benchmark_generator struct create_ipt { template - auto operator()(std::vector>& storage) + auto operator()(std::vector>& storage) -> std::enable_if_t<(ipt_num < max_items_per_thread_arg)> { using generated_config = rocprim::adjacent_difference_config; @@ -206,12 +157,12 @@ struct device_adjacent_difference_benchmark_generator } template - auto operator()(std::vector>&) + auto operator()(std::vector>&) -> std::enable_if_t {} }; - static void create(std::vector>& storage) + static void create(std::vector>& storage) { static_for_each, create_ipt>(storage); } diff --git a/benchmark/benchmark_device_adjacent_find.cpp b/benchmark/benchmark_device_adjacent_find.cpp index cedd694d7..a090798e2 100644 --- a/benchmark/benchmark_device_adjacent_find.cpp +++ b/benchmark/benchmark_device_adjacent_find.cpp @@ -22,15 +22,11 @@ #include "benchmark_device_adjacent_find.parallel.hpp" #include "benchmark_utils.hpp" -#include "cmdparser.hpp" #ifndef BENCHMARK_CONFIG_TUNING #include "../common/utils_custom_type.hpp" #endif -// gbench -#include - // HIP #include @@ -46,15 +42,7 @@ #include #endif -#ifndef DEFAULT_N -const size_t DEFAULT_BYTES = size_t{2} << 30; // 2 GiB -#endif - -#define CREATE_BENCHMARK(T, P) \ - { \ - const device_adjacent_find_benchmark instance; \ - REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ - } +#define CREATE_BENCHMARK(T, P) executor.queue_instance(device_adjacent_find_benchmark()); #define CREATE_ADJACENT_FIND_BENCHMARKS(T) \ CREATE_BENCHMARK(T, 1) \ @@ -63,56 +51,9 @@ const size_t DEFAULT_BYTES = size_t{2} << 30; // 2 GiB int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of input bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); -#ifdef BENCHMARK_CONFIG_TUNING - // optionally run an evenly split subset of benchmarks, when making multiple program invocations - parser.set_optional("parallel_instance", - "parallel_instance", - 0, - "parallel instance index"); - parser.set_optional("parallel_instances", - "parallel_instances", - 1, - "total parallel instances"); -#endif - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); - - // HIP - hipStream_t stream = 0; // default + benchmark_utils::executor executor(argc, argv, 2 * benchmark_utils::GiB, 10, 5); - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks{}; -#ifdef BENCHMARK_CONFIG_TUNING - const int parallel_instance = parser.get("parallel_instance"); - const int parallel_instances = parser.get("parallel_instances"); - config_autotune_register::register_benchmark_subset(benchmarks, - parallel_instance, - parallel_instances, - size, - seed, - stream); -#else // BENCHMARK_CONFIG_TUNING \ - // add_adjacent_find_benchmarks(benchmarks, size, seed, stream); +#ifndef BENCHMARK_CONFIG_TUNING using custom_float2 = common::custom_type; using custom_double2 = common::custom_type; using custom_int2 = common::custom_type; @@ -129,31 +70,14 @@ int main(int argc, char* argv[]) CREATE_ADJACENT_FIND_BENCHMARKS(double) CREATE_ADJACENT_FIND_BENCHMARKS(rocprim::int128_t) CREATE_ADJACENT_FIND_BENCHMARKS(rocprim::uint128_t) + // Custom types CREATE_ADJACENT_FIND_BENCHMARKS(custom_float2) CREATE_ADJACENT_FIND_BENCHMARKS(custom_double2) CREATE_ADJACENT_FIND_BENCHMARKS(custom_int2) CREATE_ADJACENT_FIND_BENCHMARKS(custom_char_double) CREATE_ADJACENT_FIND_BENCHMARKS(custom_longlong_double) -#endif // BENCHMARK_CONFIG_TUNING - - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } +#endif - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_adjacent_find.parallel.cpp.in b/benchmark/benchmark_device_adjacent_find.parallel.cpp.in index 6a5ac95a0..5d4303708 100644 --- a/benchmark/benchmark_device_adjacent_find.parallel.cpp.in +++ b/benchmark/benchmark_device_adjacent_find.parallel.cpp.in @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -28,7 +28,7 @@ #include namespace { - auto benchmarks = config_autotune_register::create_bulk( + auto unused = benchmark_utils::executor::queue_autotune( device_adjacent_find_benchmark_generator< @InputType@, @BlockSize@>::create); diff --git a/benchmark/benchmark_device_adjacent_find.parallel.hpp b/benchmark/benchmark_device_adjacent_find.parallel.hpp index 36e7fbed2..f8cd30dbe 100644 --- a/benchmark/benchmark_device_adjacent_find.parallel.hpp +++ b/benchmark/benchmark_device_adjacent_find.parallel.hpp @@ -66,7 +66,7 @@ inline std::string config_name() template -struct device_adjacent_find_benchmark : public config_autotune_interface +struct device_adjacent_find_benchmark : public benchmark_utils::autotune_interface { std::string name() const override @@ -79,14 +79,12 @@ struct device_adjacent_find_benchmark : public config_autotune_interface + ",cfg:" + config_name() + "}"); } - static constexpr size_t warmup_size = 5; - static constexpr size_t batch_size = 10; - - void run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + using input_type = InputT; using output_type = std::size_t; @@ -161,48 +159,9 @@ struct device_adjacent_find_benchmark : public config_autotune_interface launch_adjacent_find(); HIP_CHECK(hipMalloc(&d_tmp_storage, tmp_storage_size)); - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - launch_adjacent_find(); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - // Run - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) - { - launch_adjacent_find(); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - HIP_CHECK(hipGetLastError()); - HIP_CHECK(hipDeviceSynchronize()); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); + state.run([&] { launch_adjacent_find(); }); - state.SetBytesProcessed(state.iterations() * batch_size * first_adj_index - * sizeof(*d_input)); - state.SetItemsProcessed(state.iterations() * batch_size * first_adj_index); + state.set_throughput(first_adj_index, sizeof(input_type)); HIP_CHECK(hipFree(d_input)); HIP_CHECK(hipFree(d_output)); @@ -225,7 +184,8 @@ struct device_adjacent_find_benchmark_generator static constexpr unsigned int items_per_thread = 1u << ItemsPerThreadExp; using generated_config = rocprim::adjacent_find_config; - void operator()(std::vector>& storage) + void operator()( + std::vector>& storage) { storage.emplace_back( std::make_unique>()); } }; - void operator()(std::vector>& storage) + void operator()(std::vector>& storage) { static_for_each< make_index_range, @@ -241,7 +201,7 @@ struct device_adjacent_find_benchmark_generator } }; - static void create(std::vector>& storage) + static void create(std::vector>& storage) { static_for_each, create_pos>(storage); } diff --git a/benchmark/benchmark_device_batch_memcpy.cpp b/benchmark/benchmark_device_batch_memcpy.cpp index 7a5ed166a..d469b488f 100644 --- a/benchmark/benchmark_device_batch_memcpy.cpp +++ b/benchmark/benchmark_device_batch_memcpy.cpp @@ -21,11 +21,10 @@ // SOFTWARE. #include "benchmark_utils.hpp" -#include "cmdparser.hpp" #include "../common/device_batch_memcpy.hpp" +#include "../common/utils_device_ptr.hpp" -#include #include // rocPRIM @@ -50,10 +49,7 @@ #include #include -constexpr uint32_t warmup_size = 5; -constexpr int32_t max_size = 1024 * 1024; - -using offset_type = size_t; +using namespace std::string_literals; template struct BatchMemcpyData { size_t total_num_elements = 0; - ValueType* d_input = nullptr; - ValueType* d_output = nullptr; - ValueType** d_buffer_srcs = nullptr; - ValueType** d_buffer_dsts = nullptr; - BufferSizeType* d_buffer_sizes = nullptr; + common::device_ptr d_input; + common::device_ptr d_output; + common::device_ptr d_buffer_srcs; + common::device_ptr d_buffer_dsts; + common::device_ptr d_buffer_sizes; BatchMemcpyData() = default; BatchMemcpyData(const BatchMemcpyData&) = delete; - BatchMemcpyData(BatchMemcpyData&& other) - : total_num_elements{std::exchange(other.total_num_elements, 0)} - , d_input{std::exchange(other.d_input, nullptr)} - , d_output{std::exchange(other.d_output, nullptr)} - , d_buffer_srcs{std::exchange(other.d_buffer_srcs, nullptr)} - , d_buffer_dsts{std::exchange(other.d_buffer_dsts, nullptr)} - , d_buffer_sizes{std::exchange(other.d_buffer_sizes, nullptr)} - {} + BatchMemcpyData(BatchMemcpyData&& other) = default; - BatchMemcpyData& operator=(BatchMemcpyData&& other) - { - total_num_elements = std::exchange(other.total_num_elements, 0); - d_input = std::exchange(other.d_input, nullptr); - d_output = std::exchange(other.d_output, nullptr); - d_buffer_srcs = std::exchange(other.d_buffer_srcs, nullptr); - d_buffer_dsts = std::exchange(other.d_buffer_dsts, nullptr); - d_buffer_sizes = std::exchange(other.d_buffer_sizes, nullptr); - return *this; - }; + BatchMemcpyData& operator=(BatchMemcpyData&& other) = default; BatchMemcpyData& operator=(const BatchMemcpyData&) = delete; @@ -139,22 +119,15 @@ struct BatchMemcpyData return total_num_elements * sizeof(ValueType); } - ~BatchMemcpyData() - { - HIP_CHECK(hipFree(d_buffer_sizes)); - HIP_CHECK(hipFree(d_buffer_srcs)); - HIP_CHECK(hipFree(d_buffer_dsts)); - HIP_CHECK(hipFree(d_output)); - HIP_CHECK(hipFree(d_input)); - } + ~BatchMemcpyData() {} }; template BatchMemcpyData prepare_data(hipStream_t stream, const managed_seed& seed, - const int32_t num_tlev_buffers = 1024, - const int32_t num_wlev_buffers = 1024, - const int32_t num_blev_buffers = 1024) + const int32_t num_tlev_buffers, + const int32_t num_wlev_buffers, + const int32_t num_blev_buffers) { const bool shuffle_buffers = false; @@ -180,6 +153,7 @@ BatchMemcpyData prepare_data(hipStream_t stre const int32_t wlev_min_elems = rocprim::detail::ceiling_div(wlev_min_size, sizeof(ValueType)); const int32_t blev_min_elems = rocprim::detail::ceiling_div(blev_min_size, sizeof(ValueType)); + constexpr int32_t max_size = 1024 * 1024; constexpr int32_t max_elems = max_size / sizeof(ValueType); // Generate data @@ -214,12 +188,14 @@ BatchMemcpyData prepare_data(hipStream_t stre rng, result.total_num_elements * sizeof(ValueType)); - HIP_CHECK(hipMalloc(&result.d_input, result.total_num_bytes())); - HIP_CHECK(hipMalloc(&result.d_output, result.total_num_bytes())); + result.d_input.resize(result.total_num_elements); + result.d_output.resize(result.total_num_elements); + + result.d_buffer_srcs.resize(num_buffers); + result.d_buffer_dsts.resize(num_buffers); + result.d_buffer_sizes.resize(num_buffers); - HIP_CHECK(hipMalloc(&result.d_buffer_srcs, num_buffers * sizeof(ValueType*))); - HIP_CHECK(hipMalloc(&result.d_buffer_dsts, num_buffers * sizeof(ValueType*))); - HIP_CHECK(hipMalloc(&result.d_buffer_sizes, num_buffers * sizeof(BufferSizeType))); + using offset_type = size_t; // Generate the source and shuffled destination offsets. std::vector src_offsets; @@ -251,120 +227,77 @@ BatchMemcpyData prepare_data(hipStream_t stre for(size_t i = 0; i < num_buffers; ++i) { - h_buffer_srcs[i] = result.d_input + src_offsets[i]; - h_buffer_dsts[i] = result.d_output + dst_offsets[i]; + h_buffer_srcs[i] = result.d_input.get() + src_offsets[i]; + h_buffer_dsts[i] = result.d_output.get() + dst_offsets[i]; } // Prepare the batch memcpy. if(IsMemCpy) { - HIP_CHECK(hipMemcpy(result.d_input, - h_input_for_memcpy.data(), - result.total_num_bytes(), - hipMemcpyHostToDevice)); - HIP_CHECK(hipMemcpy(result.d_buffer_sizes, - h_buffer_num_bytes.data(), - h_buffer_num_bytes.size() * sizeof(BufferSizeType), - hipMemcpyHostToDevice)); + using cast_value_type = typename decltype(result.d_input)::value_type; + result.d_input.store(std::vector( + reinterpret_cast(h_input_for_memcpy.data()), + reinterpret_cast(h_input_for_memcpy.data()) + + result.total_num_elements)); + result.d_buffer_sizes.store(h_buffer_num_bytes); } else { - HIP_CHECK(hipMemcpy(result.d_input, - h_input_for_copy.data(), - result.total_num_bytes(), - hipMemcpyHostToDevice)); - HIP_CHECK(hipMemcpy(result.d_buffer_sizes, - h_buffer_num_elements.data(), - h_buffer_num_elements.size() * sizeof(BufferSizeType), - hipMemcpyHostToDevice)); + result.d_input.store( + decltype(h_input_for_copy)(h_input_for_copy.data(), + h_input_for_copy.data() + result.total_num_elements)); + result.d_buffer_sizes.store(h_buffer_num_elements); } - HIP_CHECK(hipMemcpy(result.d_buffer_srcs, - h_buffer_srcs.data(), - h_buffer_srcs.size() * sizeof(ValueType*), - hipMemcpyHostToDevice)); - HIP_CHECK(hipMemcpy(result.d_buffer_dsts, - h_buffer_dsts.data(), - h_buffer_dsts.size() * sizeof(ValueType*), - hipMemcpyHostToDevice)); + result.d_buffer_srcs.store(h_buffer_srcs); + result.d_buffer_dsts.store(h_buffer_dsts); return result; } -template -void run_benchmark(benchmark::State& state, - const managed_seed& seed, - hipStream_t stream, - const int32_t num_tlev_buffers = 1024, - const int32_t num_wlev_buffers = 1024, - const int32_t num_blev_buffers = 1024) +template +void run_benchmark(benchmark_utils::state&& state) { - const size_t num_buffers = num_tlev_buffers + num_wlev_buffers + num_blev_buffers; + const auto& stream = state.stream; + const auto& seed = state.seed; + + constexpr size_t num_buffers = NumTlevBuffers + NumWlevBuffers + NumBlevBuffers; size_t temp_storage_bytes = 0; BatchMemcpyData data; batch_copy(nullptr, temp_storage_bytes, - data.d_buffer_srcs, - data.d_buffer_dsts, - data.d_buffer_sizes, + data.d_buffer_srcs.get(), + data.d_buffer_dsts.get(), + data.d_buffer_sizes.get(), num_buffers, stream); - void* d_temp_storage = nullptr; - HIP_CHECK(hipMalloc(&d_temp_storage, temp_storage_bytes)); + common::device_ptr d_temp_storage(temp_storage_bytes); data = prepare_data(stream, seed, - num_tlev_buffers, - num_wlev_buffers, - num_blev_buffers); + NumTlevBuffers, + NumWlevBuffers, + NumBlevBuffers); - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - batch_copy(d_temp_storage, - temp_storage_bytes, - data.d_buffer_srcs, - data.d_buffer_dsts, - data.d_buffer_sizes, - num_buffers, - stream); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - batch_copy(d_temp_storage, - temp_storage_bytes, - data.d_buffer_srcs, - data.d_buffer_dsts, - data.d_buffer_sizes, - num_buffers, - stream); - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - state.SetBytesProcessed(state.iterations() * data.total_num_bytes()); - state.SetItemsProcessed(state.iterations() * data.total_num_elements); - - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - HIP_CHECK(hipFree(d_temp_storage)); + state.run( + [&] + { + batch_copy(d_temp_storage.get(), + temp_storage_bytes, + data.d_buffer_srcs.get(), + data.d_buffer_dsts.get(), + data.d_buffer_sizes.get(), + num_buffers, + stream); + }); + + state.set_throughput(data.total_num_elements, sizeof(ValueType)); } // Naive implementation used for comparison @@ -406,197 +339,135 @@ void naive_kernel(void** in_ptr, void** out_ptr, const OffsetType* sizes) } } -template -void run_naive_benchmark(benchmark::State& state, - const managed_seed& seed, - hipStream_t stream, - const int32_t num_tlev_buffers = 1024, - const int32_t num_wlev_buffers = 1024, - const int32_t num_blev_buffers = 1024) +template +void run_naive_benchmark(benchmark_utils::state&& state) { - const size_t num_buffers = num_tlev_buffers + num_wlev_buffers + num_blev_buffers; + const auto& stream = state.stream; + const auto& seed = state.seed; const auto data = prepare_data(stream, seed, - num_tlev_buffers, - num_wlev_buffers, - num_blev_buffers); - - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - naive_kernel - <<>>((void**)data.d_buffer_srcs, - (void**)data.d_buffer_dsts, - data.d_buffer_sizes); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); + NumTlevBuffers, + NumWlevBuffers, + NumBlevBuffers); - naive_kernel - <<>>((void**)data.d_buffer_srcs, - (void**)data.d_buffer_dsts, - data.d_buffer_sizes); + constexpr size_t num_buffers = NumTlevBuffers + NumWlevBuffers + NumBlevBuffers; - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - state.SetBytesProcessed(state.iterations() * data.total_num_bytes()); - state.SetItemsProcessed(state.iterations() * data.total_num_elements); + state.run( + [&] + { + naive_kernel + <<>>((void**)data.d_buffer_srcs.get(), + (void**)data.d_buffer_dsts.get(), + data.d_buffer_sizes.get()); + }); - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); + state.set_throughput(data.total_num_elements, sizeof(ValueType)); } - #define CREATE_NAIVE_BENCHMARK(item_size, \ - item_alignment, \ - size_type, \ - num_tlev, \ - num_wlev, \ - num_blev) \ - benchmark::RegisterBenchmark( \ - bench_naming::format_name( \ - "{lvl:device,item_size:" #item_size ",item_alignment:" #item_alignment \ - ",size_type:" #size_type ",algo:naive_memcpy,num_tlev:" #num_tlev \ - ",num_wlev:" #num_wlev ",num_blev:" #num_blev ",cfg:default_config}") \ - .c_str(), \ - [=](benchmark::State& state) \ - { \ - run_naive_benchmark, \ - size_type, \ - true>(state, seed, stream, num_tlev, num_wlev, num_blev); \ - }) + #define CREATE_NAIVE_BENCHMARK(item_size, \ + item_alignment, \ + size_type, \ + num_tlev, \ + num_wlev, \ + num_blev) \ + executor.queue_fn( \ + bench_naming::format_name( \ + "{lvl:device,item_size:" #item_size ",item_alignment:" #item_alignment \ + ",size_type:" #size_type ",algo:naive_memcpy,num_tlev:" #num_tlev \ + ",num_wlev:" #num_wlev ",num_blev:" #num_blev ",cfg:default_config}") \ + .c_str(), \ + [=](benchmark_utils::state&& state) \ + { \ + run_naive_benchmark, \ + size_type, \ + true, \ + num_tlev, \ + num_wlev, \ + num_blev>(std::forward(state)); \ + }); #endif // BUILD_NAIVE_BENCHMARK -#define CREATE_BENCHMARK(item_size, item_alignment, size_type, num_tlev, num_wlev, num_blev) \ - benchmark::RegisterBenchmark( \ - bench_naming::format_name("{lvl:device,item_size:" #item_size \ - ",item_alignment:" #item_alignment ",size_type:" #size_type \ - ",algo:batch_memcpy,num_tlev:" #num_tlev ",num_wlev:" #num_wlev \ - ",num_blev:" #num_blev ",cfg:default_config}") \ - .c_str(), \ - [=](benchmark::State& state) \ - { \ - run_benchmark, size_type, true>( \ - state, \ - seed, \ - stream, \ - num_tlev, \ - num_wlev, \ - num_blev); \ - run_benchmark, size_type, false>( \ - state, \ - seed, \ - stream, \ - num_tlev, \ - num_wlev, \ - num_blev); \ - }) +#define CREATE_BENCHMARK(item_size, \ + item_alignment, \ + size_type, \ + num_tlev, \ + num_wlev, \ + num_blev, \ + is_memcpy) \ + executor.queue_fn(bench_naming::format_name("{lvl:device,item_size:" #item_size \ + ",item_alignment:" #item_alignment \ + ",size_type:" #size_type ",algo:" \ + + (is_memcpy ? "batch_memcpy"s : "batch_copy"s) \ + + ",num_tlev:" #num_tlev ",num_wlev:" #num_wlev \ + ",num_blev:" #num_blev ",cfg:default_config}") \ + .c_str(), \ + [=](benchmark_utils::state&& state) \ + { \ + run_benchmark, \ + size_type, \ + is_memcpy, \ + num_tlev, \ + num_wlev, \ + num_blev>(std::forward(state)); \ + }); + +#define CREATE_NORMAL_BENCHMARK(item_size, \ + item_alignment, \ + size_type, \ + num_tlev, \ + num_wlev, \ + num_blev) \ + CREATE_BENCHMARK(item_size, item_alignment, size_type, num_tlev, num_wlev, num_blev, true) \ + CREATE_BENCHMARK(item_size, item_alignment, size_type, num_tlev, num_wlev, num_blev, false) #ifndef BUILD_NAIVE_BENCHMARK - #define BENCHMARK_TYPE(item_size, item_alignment) \ - CREATE_BENCHMARK(item_size, item_alignment, uint32_t, 100000, 0, 0), \ - CREATE_BENCHMARK(item_size, item_alignment, uint32_t, 0, 100000, 0), \ - CREATE_BENCHMARK(item_size, item_alignment, uint32_t, 0, 0, 1000), \ - CREATE_BENCHMARK(item_size, item_alignment, uint32_t, 1000, 1000, 1000), \ - CREATE_BENCHMARK(item_size, item_alignment, rocprim::uint128_t, 100000, 0, 0), \ - CREATE_BENCHMARK(item_size, item_alignment, rocprim::uint128_t, 0, 100000, 0), \ - CREATE_BENCHMARK(item_size, item_alignment, rocprim::uint128_t, 0, 0, 1000), \ - CREATE_BENCHMARK(item_size, item_alignment, rocprim::uint128_t, 1000, 1000, 1000) + #define BENCHMARK_TYPE(item_size, item_alignment) \ + CREATE_NORMAL_BENCHMARK(item_size, item_alignment, uint32_t, 100000, 0, 0) \ + CREATE_NORMAL_BENCHMARK(item_size, item_alignment, uint32_t, 0, 100000, 0) \ + CREATE_NORMAL_BENCHMARK(item_size, item_alignment, uint32_t, 0, 0, 1000) \ + CREATE_NORMAL_BENCHMARK(item_size, item_alignment, uint32_t, 1000, 1000, 1000) \ + CREATE_NORMAL_BENCHMARK(item_size, item_alignment, rocprim::uint128_t, 100000, 0, 0) \ + CREATE_NORMAL_BENCHMARK(item_size, item_alignment, rocprim::uint128_t, 0, 100000, 0) \ + CREATE_NORMAL_BENCHMARK(item_size, item_alignment, rocprim::uint128_t, 0, 0, 1000) \ + CREATE_NORMAL_BENCHMARK(item_size, item_alignment, rocprim::uint128_t, 1000, 1000, 1000) #else #define BENCHMARK_TYPE(item_size, item_alignment) \ - CREATE_BENCHMARK(item_size, item_alignment, uint32_t, 100000, 0, 0), \ - CREATE_BENCHMARK(item_size, item_alignment, uint32_t, 0, 100000, 0), \ - CREATE_BENCHMARK(item_size, item_alignment, uint32_t, 0, 0, 1000), \ - CREATE_BENCHMARK(item_size, item_alignment, uint32_t, 1000, 1000, 1000), \ - CREATE_NAIVE_BENCHMARK(item_size, item_alignment, uint32_t, 100000, 0, 0), \ - CREATE_NAIVE_BENCHMARK(item_size, item_alignment, uint32_t, 0, 100000, 0), \ - CREATE_NAIVE_BENCHMARK(item_size, item_alignment, uint32_t, 0, 0, 1000), \ - CREATE_NAIVE_BENCHMARK(item_size, item_alignment, uint32_t, 1000, 1000, 1000), \ - CREATE_BENCHMARK(item_size, item_alignment, rocprim::uint128_t, 100000, 0, 0), \ - CREATE_BENCHMARK(item_size, item_alignment, rocprim::uint128_t, 0, 100000, 0), \ - CREATE_BENCHMARK(item_size, item_alignment, rocprim::uint128_t, 0, 0, 1000), \ - CREATE_BENCHMARK(item_size, item_alignment, rocprim::uint128_t, 1000, 1000, 1000), \ - CREATE_NAIVE_BENCHMARK(item_size, item_alignment, rocprim::uint128_t, 100000, 0, 0), \ - CREATE_NAIVE_BENCHMARK(item_size, item_alignment, rocprim::uint128_t, 0, 100000, 0), \ - CREATE_NAIVE_BENCHMARK(item_size, item_alignment, rocprim::uint128_t, 0, 0, 1000), \ - CREATE_NAIVE_BENCHMARK(item_size, \ - item_alignment, \ - rocprim::uint128_t, \ - 1000, \ - 1000, \ - 1000) + CREATE_NORMAL_BENCHMARK(item_size, item_alignment, uint32_t, 100000, 0, 0) \ + CREATE_NORMAL_BENCHMARK(item_size, item_alignment, uint32_t, 0, 100000, 0) \ + CREATE_NORMAL_BENCHMARK(item_size, item_alignment, uint32_t, 0, 0, 1000) \ + CREATE_NORMAL_BENCHMARK(item_size, item_alignment, uint32_t, 1000, 1000, 1000) \ + CREATE_NAIVE_BENCHMARK(item_size, item_alignment, uint32_t, 100000, 0, 0) \ + CREATE_NAIVE_BENCHMARK(item_size, item_alignment, uint32_t, 0, 100000, 0) \ + CREATE_NAIVE_BENCHMARK(item_size, item_alignment, uint32_t, 0, 0, 1000) \ + CREATE_NAIVE_BENCHMARK(item_size, item_alignment, uint32_t, 1000, 1000, 1000) \ + CREATE_NORMAL_BENCHMARK(item_size, item_alignment, rocprim::uint128_t, 100000, 0, 0) \ + CREATE_NORMAL_BENCHMARK(item_size, item_alignment, rocprim::uint128_t, 0, 100000, 0) \ + CREATE_NORMAL_BENCHMARK(item_size, item_alignment, rocprim::uint128_t, 0, 0, 1000) \ + CREATE_NORMAL_BENCHMARK(item_size, item_alignment, rocprim::uint128_t, 1000, 1000, 1000) \ + CREATE_NAIVE_BENCHMARK(item_size, item_alignment, rocprim::uint128_t, 100000, 0, 0) \ + CREATE_NAIVE_BENCHMARK(item_size, item_alignment, rocprim::uint128_t, 0, 100000, 0) \ + CREATE_NAIVE_BENCHMARK(item_size, item_alignment, rocprim::uint128_t, 0, 0, 1000) \ + CREATE_NAIVE_BENCHMARK(item_size, item_alignment, rocprim::uint128_t, 1000, 1000, 1000) #endif //BUILD_NAIVE_BENCHMARK int32_t main(int32_t argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", 1024, "number of values"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); - const int32_t trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); - - // HIP - hipStream_t stream = hipStreamDefault; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks = {BENCHMARK_TYPE(1, 1), - BENCHMARK_TYPE(1, 2), - BENCHMARK_TYPE(1, 4), - BENCHMARK_TYPE(1, 8), - BENCHMARK_TYPE(2, 2), - BENCHMARK_TYPE(4, 4), - BENCHMARK_TYPE(8, 8)}; - - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } + benchmark_utils::executor executor(argc, argv, 0, 1, 5); - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } + BENCHMARK_TYPE(1, 1) + BENCHMARK_TYPE(1, 2) + BENCHMARK_TYPE(1, 4) + BENCHMARK_TYPE(1, 8) + BENCHMARK_TYPE(2, 2) + BENCHMARK_TYPE(4, 4) + BENCHMARK_TYPE(8, 8) - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_binary_search.cpp b/benchmark/benchmark_device_binary_search.cpp index 97a0010a2..336ade647 100644 --- a/benchmark/benchmark_device_binary_search.cpp +++ b/benchmark/benchmark_device_binary_search.cpp @@ -23,13 +23,9 @@ #include "benchmark_device_binary_search.parallel.hpp" #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" #include "../common/utils_custom_type.hpp" - -// Google Benchmark -#include +#include "../common/utils_device_ptr.hpp" // HIP API #include @@ -49,242 +45,46 @@ #include #endif -#ifndef DEFAULT_N -const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; -#endif - -const unsigned int batch_size = 10; -const unsigned int warmup_size = 5; - -template -void run_benchmark(benchmark::State& state, - size_t haystack_bytes, - const managed_seed& seed, - hipStream_t stream, - size_t needles_bytes, - bool sorted_needles) -{ - using haystack_type = T; - using needle_type = T; - using output_type = size_t; - using compare_op_type = - typename std::conditional::value, - half_less, - rocprim::less>::type; - - // Calculate the number of elements from byte size - size_t haystack_size = haystack_bytes / sizeof(haystack_type); - size_t needles_size = needles_bytes / sizeof(needle_type); - - compare_op_type compare_op; - // Generate data - std::vector haystack(haystack_size); - std::iota(haystack.begin(), haystack.end(), 0); - - const auto random_range = limit_random_range(0, haystack_size); - - std::vector needles = get_random_data(needles_size, - random_range.first, - random_range.second, - seed.get_0()); - if(sorted_needles) - { - std::sort(needles.begin(), needles.end(), compare_op); - } - - haystack_type* d_haystack; - needle_type* d_needles; - output_type* d_output; - HIP_CHECK( - hipMalloc(reinterpret_cast(&d_haystack), haystack_size * sizeof(haystack_type))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_needles), needles_size * sizeof(needle_type))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_output), needles_size * sizeof(output_type))); - HIP_CHECK(hipMemcpy(d_haystack, - haystack.data(), - haystack_size * sizeof(haystack_type), - hipMemcpyHostToDevice)); - HIP_CHECK(hipMemcpy(d_needles, - needles.data(), - needles_size * sizeof(needle_type), - hipMemcpyHostToDevice)); - - void* d_temporary_storage = nullptr; - size_t temporary_storage_bytes; - auto dispatch_helper = dispatch_binary_search_helper(); - HIP_CHECK(dispatch_helper.dispatch_binary_search(AlgorithmSelectorTag{}, - d_temporary_storage, - temporary_storage_bytes, - d_haystack, - d_needles, - d_output, - haystack_size, - needles_size, - compare_op, - stream)); - - HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); - - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - HIP_CHECK(dispatch_helper.dispatch_binary_search(AlgorithmSelectorTag{}, - d_temporary_storage, - temporary_storage_bytes, - d_haystack, - d_needles, - d_output, - haystack_size, - needles_size, - compare_op, - stream)); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) - { - HIP_CHECK(dispatch_helper.dispatch_binary_search(AlgorithmSelectorTag{}, - d_temporary_storage, - temporary_storage_bytes, - d_haystack, - d_needles, - d_output, - haystack_size, - needles_size, - compare_op, - stream)); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * needles_size * sizeof(needle_type)); - state.SetItemsProcessed(state.iterations() * batch_size * needles_size); - - HIP_CHECK(hipFree(d_temporary_storage)); - HIP_CHECK(hipFree(d_haystack)); - HIP_CHECK(hipFree(d_needles)); - HIP_CHECK(hipFree(d_output)); -} - -#define CREATE_BENCHMARK(T, K, SORTED, ALGO_TAG) \ - benchmark::RegisterBenchmark( \ - bench_naming::format_name( \ - "{lvl:device,algo:" + ALGO_TAG{}.name() + ",key_type:" #T ",subalgo:" #K "_percent_" \ - + std::string(SORTED ? "sorted" : "random") + "_needles,cfg:default_config}") \ - .c_str(), \ - [=](benchmark::State& state) \ - { run_benchmark(state, bytes, seed, stream, bytes * K / 100, SORTED); }) - -#define BENCHMARK_ALGORITHMS(T, K, SORTED) \ - CREATE_BENCHMARK(T, K, SORTED, binary_search_subalgorithm), \ - CREATE_BENCHMARK(T, K, SORTED, lower_bound_subalgorithm), \ - CREATE_BENCHMARK(T, K, SORTED, upper_bound_subalgorithm) - -#define BENCHMARK_TYPE(type) \ - BENCHMARK_ALGORITHMS(type, 10, true), BENCHMARK_ALGORITHMS(type, 10, false) +#define CREATE_BENCHMARK(T, K, SORTED, SUBALGORITHM) \ + executor.queue_fn( \ + bench_naming::format_name("{lvl:device,algo:" + SUBALGORITHM{}.name() \ + + ",key_type:" #T ",subalgo:" #K "_percent_" \ + + std::string(SORTED ? "sorted" : "random") \ + + "_needles,cfg:default_config}") \ + .c_str(), \ + [=](benchmark_utils::state&& state) \ + { \ + device_binary_search_benchmark().run( \ + std::forward(state)); \ + }); + +#define BENCHMARK_ALGORITHMS(T, K, SORTED) \ + CREATE_BENCHMARK(T, K, SORTED, binary_search_subalgorithm) \ + CREATE_BENCHMARK(T, K, SORTED, lower_bound_subalgorithm) \ + CREATE_BENCHMARK(T, K, SORTED, upper_bound_subalgorithm) + +#define BENCHMARK_TYPE(type) \ + BENCHMARK_ALGORITHMS(type, 10, true) \ + BENCHMARK_ALGORITHMS(type, 10, false) int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); -#ifdef BENCHMARK_CONFIG_TUNING - // optionally run an evenly split subset of benchmarks, when making multiple program invocations - parser.set_optional("parallel_instance", - "parallel_instance", - 0, - "parallel instance index"); - parser.set_optional("parallel_instances", - "parallel_instances", - 1, - "total parallel instances"); -#endif - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); - - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); + benchmark_utils::executor executor(argc, argv, 128 * benchmark_utils::MiB, 10, 5); - // Add benchmarks - std::vector benchmarks; -#ifdef BENCHMARK_CONFIG_TUNING - const int parallel_instance = parser.get("parallel_instance"); - const int parallel_instances = parser.get("parallel_instances"); - config_autotune_register::register_benchmark_subset(benchmarks, - parallel_instance, - parallel_instances, - bytes, - seed, - stream); -#else // BENCHMARK_CONFIG_TUNING +#ifndef BENCHMARK_CONFIG_TUNING using custom_float2 = common::custom_type; using custom_double2 = common::custom_type; - benchmarks = {BENCHMARK_TYPE(float), - BENCHMARK_TYPE(double), - BENCHMARK_TYPE(int8_t), - BENCHMARK_TYPE(uint8_t), - BENCHMARK_TYPE(rocprim::half), - BENCHMARK_TYPE(rocprim::int128_t), - BENCHMARK_TYPE(rocprim::uint128_t), - BENCHMARK_TYPE(custom_float2), - BENCHMARK_TYPE(custom_double2)}; -#endif // BENCHMARK_CONFIG_TUNING - - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } + BENCHMARK_TYPE(float) + BENCHMARK_TYPE(double) + BENCHMARK_TYPE(int8_t) + BENCHMARK_TYPE(uint8_t) + BENCHMARK_TYPE(rocprim::half) + BENCHMARK_TYPE(rocprim::int128_t) + BENCHMARK_TYPE(rocprim::uint128_t) + BENCHMARK_TYPE(custom_float2) + BENCHMARK_TYPE(custom_double2) +#endif - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_binary_search.parallel.cpp.in b/benchmark/benchmark_device_binary_search.parallel.cpp.in index 913cdb695..292deb89a 100644 --- a/benchmark/benchmark_device_binary_search.parallel.cpp.in +++ b/benchmark/benchmark_device_binary_search.parallel.cpp.in @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2023-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -30,9 +30,11 @@ namespace { -auto benchmarks = config_autotune_register::create>>(); } diff --git a/benchmark/benchmark_device_binary_search.parallel.hpp b/benchmark/benchmark_device_binary_search.parallel.hpp index 8e1544ce4..9c1a33a4d 100644 --- a/benchmark/benchmark_device_binary_search.parallel.hpp +++ b/benchmark/benchmark_device_binary_search.parallel.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2023-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -25,6 +25,8 @@ #include "benchmark_utils.hpp" +#include "../common/utils_device_ptr.hpp" + #include #include #include @@ -111,121 +113,102 @@ struct dispatch_binary_search_helper } }; -template -struct device_binary_search_benchmark : public config_autotune_interface +template +std::string binary_search_config_name() +{ + return "{bs:" + std::to_string(Config::block_size) + + ",ipt:" + std::to_string(Config::items_per_thread) + "}"; +} + +template<> +inline std::string binary_search_config_name() +{ + return "default_config"; +} + +template +struct device_binary_search_benchmark : public benchmark_utils::autotune_interface { std::string name() const override { return bench_naming::format_name("{lvl:device,algo:" + SubAlgorithm{}.name() + ",value_type:" + std::string(Traits::name()) + ",output_type:" + std::string(Traits::name()) - + ",cfg:{bs:" + std::to_string(Config::block_size) - + ",ipt:" + std::to_string(Config::items_per_thread) - + "}}"); + + ",cfg:" + binary_search_config_name() + "}"); } - void run(benchmark::State& state, - size_t haystack_size, - const managed_seed& seed, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { - using compare_op_t = rocprim::less; - const auto needles_size = haystack_size / 10; - compare_op_t compare_op; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + const auto& stream = state.stream; + + size_t needles_bytes = bytes * K / 100; + + using compare_op_type = typename std:: + conditional::value, half_less, rocprim::less>::type; + // Calculate the number of elements from byte size + size_t haystack_size = bytes / sizeof(T); + size_t needles_size = needles_bytes / sizeof(T); + + compare_op_type compare_op; + + // Generate data std::vector haystack(haystack_size); std::iota(haystack.begin(), haystack.end(), 0); - const auto random_range = limit_random_range(0, haystack_size); - std::vector needles = get_random_data(needles_size, + const auto random_range = limit_random_range(0, haystack_size); + + std::vector needles = get_random_data(needles_size, random_range.first, random_range.second, seed.get_0()); - T* d_haystack; - T* d_needles; - OutputType* d_output; - HIP_CHECK(hipMalloc(&d_haystack, haystack_size * sizeof(*d_haystack))); - HIP_CHECK(hipMalloc(&d_needles, needles_size * sizeof(*d_needles))); - HIP_CHECK(hipMalloc(&d_output, needles_size * sizeof(*d_output))); - HIP_CHECK(hipMemcpy(d_haystack, - haystack.data(), - haystack_size * sizeof(*d_haystack), - hipMemcpyHostToDevice)); - HIP_CHECK(hipMemcpy(d_needles, - needles.data(), - needles_size * sizeof(*d_needles), - hipMemcpyHostToDevice)); - - void* d_temporary_storage = nullptr; - size_t temporary_storage_bytes; - auto dispatch_helper = dispatch_binary_search_helper(); - HIP_CHECK(dispatch_helper.dispatch_binary_search(SubAlgorithm{}, - d_temporary_storage, - temporary_storage_bytes, - d_haystack, - d_needles, - d_output, - haystack_size, - needles_size, - compare_op, - stream)); + if(SortedNeedles) + { + std::sort(needles.begin(), needles.end(), compare_op); + } - HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); + common::device_ptr d_haystack(haystack); + common::device_ptr d_needles(needles); + common::device_ptr d_output(needles_size); - // Warm-up + size_t temporary_storage_bytes; + auto dispatch_helper = dispatch_binary_search_helper(); HIP_CHECK(dispatch_helper.dispatch_binary_search(SubAlgorithm{}, - d_temporary_storage, + nullptr, temporary_storage_bytes, - d_haystack, - d_needles, - d_output, + d_haystack.get(), + d_needles.get(), + d_output.get(), haystack_size, needles_size, compare_op, stream)); - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - HIP_CHECK(dispatch_helper.dispatch_binary_search(SubAlgorithm{}, - d_temporary_storage, - temporary_storage_bytes, - d_haystack, - d_needles, - d_output, - haystack_size, - needles_size, - compare_op, - stream)); - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * needles_size * sizeof(T)); - state.SetItemsProcessed(state.iterations() * needles_size); - HIP_CHECK(hipFree(d_temporary_storage)); - HIP_CHECK(hipFree(d_haystack)); - HIP_CHECK(hipFree(d_needles)); - HIP_CHECK(hipFree(d_output)); + common::device_ptr d_temporary_storage(temporary_storage_bytes); + + state.run( + [&] + { + HIP_CHECK(dispatch_helper.dispatch_binary_search(SubAlgorithm{}, + d_temporary_storage.get(), + temporary_storage_bytes, + d_haystack.get(), + d_needles.get(), + d_output.get(), + haystack_size, + needles_size, + compare_op, + stream)); + }); + + state.set_throughput(needles_size, sizeof(T)); } }; diff --git a/benchmark/benchmark_device_find_end.cpp b/benchmark/benchmark_device_find_end.cpp index 67506d2d0..f6281a862 100644 --- a/benchmark/benchmark_device_find_end.cpp +++ b/benchmark/benchmark_device_find_end.cpp @@ -23,14 +23,8 @@ #include "benchmark_device_find_end.hpp" #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" - #include "../common/utils_custom_type.hpp" -// Google Benchmark -#include - // HIP API #include @@ -41,59 +35,23 @@ #include #include -#ifndef DEFAULT_BYTES -const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; -#endif +#define CREATE_BENCHMARK_FIND_END(TYPE, KEY_SIZE, REPEATING) \ + executor.queue_instance(device_find_end_benchmark(KEY_SIZE, REPEATING)); -#define CREATE_BENCHMARK_FIND_END(TYPE, KEY_SIZE, REPEATING) \ - { \ - const device_find_end_benchmark instance(KEY_SIZE, REPEATING); \ - REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ - } +#define CREATE_BENCHMARK_PATTERN(TYPE, REPEATING) \ + CREATE_BENCHMARK_FIND_END(TYPE, 10, REPEATING) \ + CREATE_BENCHMARK_FIND_END(TYPE, 100, REPEATING) \ + CREATE_BENCHMARK_FIND_END(TYPE, 1000, REPEATING) \ + CREATE_BENCHMARK_FIND_END(TYPE, 10000, REPEATING) -#define CREATE_BENCHMARK_PATTERN(TYPE, REPEATING) \ - { \ - CREATE_BENCHMARK_FIND_END(TYPE, 10, REPEATING) \ - CREATE_BENCHMARK_FIND_END(TYPE, 100, REPEATING) \ - CREATE_BENCHMARK_FIND_END(TYPE, 1000, REPEATING) \ - CREATE_BENCHMARK_FIND_END(TYPE, 10000, REPEATING) \ - } - -#define CREATE_BENCHMARK(TYPE) \ - { \ - CREATE_BENCHMARK_PATTERN(TYPE, true) CREATE_BENCHMARK_PATTERN(TYPE, false) \ - } +#define CREATE_BENCHMARK(TYPE) \ + CREATE_BENCHMARK_PATTERN(TYPE, true) \ + CREATE_BENCHMARK_PATTERN(TYPE, false) int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("bytes", "bytes", DEFAULT_BYTES, "number of values"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("bytes"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); + benchmark_utils::executor executor(argc, argv, 128 * benchmark_utils::MiB, 10, 5); - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks{}; CREATE_BENCHMARK(int) CREATE_BENCHMARK(long long) CREATE_BENCHMARK(int8_t) @@ -116,23 +74,5 @@ int main(int argc, char* argv[]) CREATE_BENCHMARK(custom_char_double) CREATE_BENCHMARK(custom_longlong_double) - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } - - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_find_end.hpp b/benchmark/benchmark_device_find_end.hpp index 6a4e207e5..8260f93c9 100644 --- a/benchmark/benchmark_device_find_end.hpp +++ b/benchmark/benchmark_device_find_end.hpp @@ -26,6 +26,7 @@ #include "benchmark_utils.hpp" #include "../common/utils_data_generation.hpp" +#include "../common/utils_device_ptr.hpp" // Google Benchmark #include @@ -44,7 +45,7 @@ #include template -struct device_find_end_benchmark : public config_autotune_interface +struct device_find_end_benchmark : public benchmark_utils::autotune_interface { size_t key_size_ = 10; bool repeating_ = false; @@ -64,14 +65,12 @@ struct device_find_end_benchmark : public config_autotune_interface + ",value_type:" + std::string(Traits::name()) + ",cfg:default_config}"); } - static constexpr unsigned int batch_size = 10; - static constexpr unsigned int warmup_size = 5; - - void run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + using key_type = Key; using output_type = size_t; @@ -105,98 +104,43 @@ struct device_find_end_benchmark : public config_autotune_interface seed.get_0() + 1); } - key_type* d_keys_input; - key_type* d_input; - output_type* d_output; - HIP_CHECK(hipMalloc(&d_keys_input, key_size * sizeof(*d_keys_input))); - HIP_CHECK(hipMalloc(&d_input, size * sizeof(*d_input))); - HIP_CHECK(hipMalloc(&d_output, sizeof(*d_output))); - - HIP_CHECK(hipMemcpy(d_input, input.data(), size * sizeof(*d_input), hipMemcpyHostToDevice)); - - HIP_CHECK(hipMemcpy(d_keys_input, - keys_input.data(), - key_size * sizeof(*d_keys_input), - hipMemcpyHostToDevice)); + common::device_ptr d_keys_input(keys_input); + common::device_ptr d_input(input); + common::device_ptr d_output(1); rocprim::equal_to compare_op; - void* d_temporary_storage = nullptr; size_t temporary_storage_bytes = 0; - HIP_CHECK(rocprim::find_end(d_temporary_storage, + HIP_CHECK(rocprim::find_end(nullptr, temporary_storage_bytes, - d_input, - d_keys_input, - d_output, + d_input.get(), + d_keys_input.get(), + d_output.get(), size, key_size, compare_op, stream, false)); - HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); - - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - HIP_CHECK(rocprim::find_end(d_temporary_storage, - temporary_storage_bytes, - d_input, - d_keys_input, - d_output, - size, - key_size, - compare_op, - stream, - false)); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); + common::device_ptr d_temporary_storage(temporary_storage_bytes); - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) + state.run( + [&] { - HIP_CHECK(rocprim::find_end(d_temporary_storage, + HIP_CHECK(rocprim::find_end(d_temporary_storage.get(), temporary_storage_bytes, - d_input, - d_keys_input, - d_output, + d_input.get(), + d_keys_input.get(), + d_output.get(), size, key_size, compare_op, stream, false)); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(*d_input)); - state.SetItemsProcessed(state.iterations() * batch_size * size); + }); - HIP_CHECK(hipFree(d_temporary_storage)); - HIP_CHECK(hipFree(d_keys_input)); - HIP_CHECK(hipFree(d_input)); - HIP_CHECK(hipFree(d_output)); + state.set_throughput(size, sizeof(key_type)); } }; diff --git a/benchmark/benchmark_device_find_first_of.cpp b/benchmark/benchmark_device_find_first_of.cpp index f31f4fe0d..afb2ec1c3 100644 --- a/benchmark/benchmark_device_find_first_of.cpp +++ b/benchmark/benchmark_device_find_first_of.cpp @@ -23,16 +23,10 @@ #include "benchmark_device_find_first_of.parallel.hpp" #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" - #ifndef BENCHMARK_CONFIG_TUNING #include "../common/utils_custom_type.hpp" #endif -// Google Benchmark -#include - // HIP API #include @@ -47,85 +41,28 @@ #include #endif -#ifndef DEFAULT_BYTES -constexpr size_t DEFAULT_BYTES = size_t{1} << 27; // 128 MiB -#endif - -#define CREATE_BENCHMARK_FIND_FIRST_OF(TYPE, KEYS_SIZE, FIRST_OCCURENCE) \ - { \ - const device_find_first_of_benchmark instance(KEYS_SIZE, FIRST_OCCURENCE); \ - REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ - } +#define CREATE_BENCHMARK_FIND_FIRST_OF(TYPE, KEYS_SIZE, FIRST_OCCURENCE) \ + executor.queue_instance(device_find_first_of_benchmark(KEYS_SIZE, FIRST_OCCURENCE)); // clang-format off -#define CREATE_BENCHMARK0(TYPE, KEYS_SIZE) \ - { \ - CREATE_BENCHMARK_FIND_FIRST_OF(TYPE, KEYS_SIZE, 0.1) \ - CREATE_BENCHMARK_FIND_FIRST_OF(TYPE, KEYS_SIZE, 0.5) \ - CREATE_BENCHMARK_FIND_FIRST_OF(TYPE, KEYS_SIZE, 1.0) \ - } - -#define CREATE_BENCHMARK(TYPE) \ - { \ - CREATE_BENCHMARK0(TYPE, 1) \ - CREATE_BENCHMARK0(TYPE, 10) \ - CREATE_BENCHMARK0(TYPE, 100) \ - CREATE_BENCHMARK0(TYPE, 1000) \ - CREATE_BENCHMARK0(TYPE, 10000) \ - } +#define CREATE_BENCHMARK0(TYPE, KEYS_SIZE) \ + CREATE_BENCHMARK_FIND_FIRST_OF(TYPE, KEYS_SIZE, 0.1) \ + CREATE_BENCHMARK_FIND_FIRST_OF(TYPE, KEYS_SIZE, 0.5) \ + CREATE_BENCHMARK_FIND_FIRST_OF(TYPE, KEYS_SIZE, 1.0) + +#define CREATE_BENCHMARK(TYPE) \ + CREATE_BENCHMARK0(TYPE, 1) \ + CREATE_BENCHMARK0(TYPE, 10) \ + CREATE_BENCHMARK0(TYPE, 100) \ + CREATE_BENCHMARK0(TYPE, 1000) \ + CREATE_BENCHMARK0(TYPE, 10000) // clang-format on int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); -#ifdef BENCHMARK_CONFIG_TUNING - // optionally run an evenly split subset of benchmarks, when making multiple program invocations - parser.set_optional("parallel_instance", - "parallel_instance", - 0, - "parallel instance index"); - parser.set_optional("parallel_instances", - "parallel_instances", - 1, - "total parallel instances"); -#endif - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); + benchmark_utils::executor executor(argc, argv, 128 * benchmark_utils::MiB, 10, 2); - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks{}; -#ifdef BENCHMARK_CONFIG_TUNING - const int parallel_instance = parser.get("parallel_instance"); - const int parallel_instances = parser.get("parallel_instances"); - config_autotune_register::register_benchmark_subset(benchmarks, - parallel_instance, - parallel_instances, - size, - seed, - stream); -#else // BENCHMARK_CONFIG_TUNING +#ifndef BENCHMARK_CONFIG_TUNING CREATE_BENCHMARK(int8_t) CREATE_BENCHMARK(int16_t) CREATE_BENCHMARK(int32_t) @@ -140,25 +77,7 @@ int main(int argc, char* argv[]) CREATE_BENCHMARK(custom_int2) CREATE_BENCHMARK(custom_longlong_double) -#endif // BENCHMARK_CONFIG_TUNING - - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } +#endif - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_find_first_of.parallel.cpp.in b/benchmark/benchmark_device_find_first_of.parallel.cpp.in index 5bf3acfe0..e2149cc04 100644 --- a/benchmark/benchmark_device_find_first_of.parallel.cpp.in +++ b/benchmark/benchmark_device_find_first_of.parallel.cpp.in @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -29,6 +29,6 @@ namespace { -auto benchmarks = config_autotune_register::create_bulk( +auto unused = benchmark_utils::executor::queue_autotune( device_find_first_of_benchmark_generator<@DataType@, @BlockSize@>::create); } // namespace diff --git a/benchmark/benchmark_device_find_first_of.parallel.hpp b/benchmark/benchmark_device_find_first_of.parallel.hpp index b6cbac768..8e8f225e7 100644 --- a/benchmark/benchmark_device_find_first_of.parallel.hpp +++ b/benchmark/benchmark_device_find_first_of.parallel.hpp @@ -60,7 +60,7 @@ inline std::string config_name() } template -struct device_find_first_of_benchmark : public config_autotune_interface +struct device_find_first_of_benchmark : public benchmark_utils::autotune_interface { std::vector keys_sizes; std::vector first_occurrences; @@ -91,14 +91,12 @@ struct device_find_first_of_benchmark : public config_autotune_interface + "}"); } - static constexpr unsigned int batch_size = 10; - static constexpr unsigned int warmup_size = 2; - - void run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + using type = T; using key_type = T; using output_type = size_t; @@ -175,30 +173,8 @@ struct device_find_first_of_benchmark : public config_autotune_interface temporary_storage_bytes = max_temporary_storage_bytes; HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - for(size_t fi = 0; fi < first_occurrences.size(); ++fi) - { - for(size_t keys_size : keys_sizes) - { - run(keys_size, d_inputs[fi]); - } - } - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) + state.run( + [&] { for(size_t fi = 0; fi < first_occurrences.size(); ++fi) { @@ -207,20 +183,7 @@ struct device_find_first_of_benchmark : public config_autotune_interface run(keys_size, d_inputs[fi]); } } - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); + }); // Only a part of data (before the first occurrence) must be actually processed. In ideal // cases when no thread blocks do unneeded work (i.e. exit early once the match is found), @@ -235,17 +198,17 @@ struct device_find_first_of_benchmark : public config_autotune_interface { sum_keys_size += keys_size; } - state.SetBytesProcessed(state.iterations() * batch_size * sum_effective_size - * sizeof(*d_inputs[0])); - state.SetItemsProcessed(state.iterations() * batch_size * sum_effective_size); + + state.set_throughput(sum_effective_size, sizeof(type)); + // Each input is read once but all keys are read by all threads so performance is likely // compute-bound or bound by cache bandwidth for reading keys rather than reading inputs. // Let's additionally report the rate of comparisons to see if it reaches a plateau with // increasing keys_size. - state.counters["comparisons_per_second"] - = benchmark::Counter(static_cast(state.iterations() * batch_size - * sum_effective_size * sum_keys_size), - benchmark::Counter::kIsRate); + state.gbench_state.counters["comparisons_per_second"] = benchmark::Counter( + static_cast(state.gbench_state.iterations() * state.batch_iterations + * sum_effective_size * sum_keys_size), + benchmark::Counter::kIsRate); for(size_t fi = 0; fi < first_occurrences.size(); ++fi) { @@ -266,7 +229,7 @@ struct device_find_first_of_benchmark_generator { using generated_config = rocprim::find_first_of_config; - void operator()(std::vector>& storage) + void operator()(std::vector>& storage) { std::vector keys_sizes{1, 10, 100, 1000}; std::vector first_occurrences{0.1, 0.5, 1.0}; @@ -277,7 +240,7 @@ struct device_find_first_of_benchmark_generator } }; - static void create(std::vector>& storage) + static void create(std::vector>& storage) { static constexpr unsigned int min_items_per_thread = 1; static constexpr unsigned int max_items_per_thread = 16; diff --git a/benchmark/benchmark_device_histogram.cpp b/benchmark/benchmark_device_histogram.cpp index 01e26449e..173c190f5 100644 --- a/benchmark/benchmark_device_histogram.cpp +++ b/benchmark/benchmark_device_histogram.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -23,8 +23,7 @@ #include "benchmark_device_histogram.parallel.hpp" #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" +#include "../common/utils_device_ptr.hpp" // Google Benchmark #include @@ -41,13 +40,6 @@ #include #include -#ifndef DEFAULT_BYTES -const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; -#endif - -const unsigned int batch_size = 10; -const unsigned int warmup_size = 5; - int get_entropy_percents(int entropy_reduction) { switch(entropy_reduction) @@ -61,17 +53,15 @@ int get_entropy_percents(int entropy_reduction) } } -const int entropy_reductions[] = {0, 2, 4, 6}; - template -void run_even_benchmark(benchmark::State& state, - size_t bytes, - const managed_seed&, - hipStream_t stream, - size_t bins, - size_t scale, - int entropy_reduction) +void run_even_benchmark(benchmark_utils::state&& state, + size_t bins, + size_t scale, + int entropy_reduction) { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + // Calculate the number of elements size_t size = bytes / sizeof(T); @@ -85,98 +75,51 @@ void run_even_benchmark(benchmark::State& state, // Generate data std::vector input = generate(size, entropy_reduction, lower_level, upper_level); - T* d_input; - counter_type* d_histogram; - HIP_CHECK(hipMalloc(&d_input, size * sizeof(T))); - HIP_CHECK(hipMalloc(&d_histogram, size * sizeof(counter_type))); - HIP_CHECK(hipMemcpy(d_input, input.data(), size * sizeof(T), hipMemcpyHostToDevice)); + common::device_ptr d_input(input); + common::device_ptr d_histogram(size); - void* d_temporary_storage = nullptr; size_t temporary_storage_bytes = 0; - HIP_CHECK(rocprim::histogram_even(d_temporary_storage, + HIP_CHECK(rocprim::histogram_even(nullptr, temporary_storage_bytes, - d_input, + d_input.get(), size, - d_histogram, + d_histogram.get(), bins + 1, lower_level, upper_level, stream, false)); - HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); + common::device_ptr d_temporary_storage(temporary_storage_bytes); HIP_CHECK(hipDeviceSynchronize()); - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - HIP_CHECK(rocprim::histogram_even(d_temporary_storage, - temporary_storage_bytes, - d_input, - size, - d_histogram, - bins + 1, - lower_level, - upper_level, - stream, - false)); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) + state.run( + [&] { - HIP_CHECK(rocprim::histogram_even(d_temporary_storage, + HIP_CHECK(rocprim::histogram_even(d_temporary_storage.get(), temporary_storage_bytes, - d_input, + d_input.get(), size, - d_histogram, + d_histogram.get(), bins + 1, lower_level, upper_level, stream, false)); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); + }); - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(T)); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_temporary_storage)); - HIP_CHECK(hipFree(d_input)); - HIP_CHECK(hipFree(d_histogram)); + state.set_throughput(size, sizeof(T)); } template -void run_multi_even_benchmark(benchmark::State& state, - size_t bytes, - const managed_seed&, - hipStream_t stream, - size_t bins, - size_t scale, - int entropy_reduction) +void run_multi_even_benchmark(benchmark_utils::state&& state, + size_t bins, + size_t scale, + int entropy_reduction) { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + // Calculate the number of elements size_t size = bytes / sizeof(T); @@ -198,20 +141,17 @@ void run_multi_even_benchmark(benchmark::State& state, std::vector input = generate(size * Channels, entropy_reduction, lower_level[0], upper_level[0]); - T* d_input; - counter_type* d_histogram[ActiveChannels]; - HIP_CHECK(hipMalloc(&d_input, size * Channels * sizeof(T))); + common::device_ptr d_input(input); + counter_type* d_histogram[ActiveChannels]; for(unsigned int channel = 0; channel < ActiveChannels; ++channel) { HIP_CHECK(hipMalloc(&d_histogram[channel], bins * sizeof(counter_type))); } - HIP_CHECK(hipMemcpy(d_input, input.data(), size * Channels * sizeof(T), hipMemcpyHostToDevice)); - void* d_temporary_storage = nullptr; size_t temporary_storage_bytes = 0; - HIP_CHECK((rocprim::multi_histogram_even(d_temporary_storage, + HIP_CHECK((rocprim::multi_histogram_even(nullptr, temporary_storage_bytes, - d_input, + d_input.get(), size, d_histogram, num_levels, @@ -220,41 +160,16 @@ void run_multi_even_benchmark(benchmark::State& state, stream, false))); - HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); + common::device_ptr d_temporary_storage(temporary_storage_bytes); HIP_CHECK(hipDeviceSynchronize()); - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - HIP_CHECK((rocprim::multi_histogram_even(d_temporary_storage, - temporary_storage_bytes, - d_input, - size, - d_histogram, - num_levels, - lower_level, - upper_level, - stream, - false))); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) + state.run( + [&] { HIP_CHECK( - (rocprim::multi_histogram_even(d_temporary_storage, + (rocprim::multi_histogram_even(d_temporary_storage.get(), temporary_storage_bytes, - d_input, + d_input.get(), size, d_histogram, num_levels, @@ -262,26 +177,10 @@ void run_multi_even_benchmark(benchmark::State& state, upper_level, stream, false))); - } + }); - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); + state.set_throughput(size * Channels, sizeof(T)); - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * Channels * sizeof(T)); - state.SetItemsProcessed(state.iterations() * batch_size * size * Channels); - - HIP_CHECK(hipFree(d_temporary_storage)); - HIP_CHECK(hipFree(d_input)); for(unsigned int channel = 0; channel < ActiveChannels; ++channel) { HIP_CHECK(hipFree(d_histogram[channel])); @@ -289,12 +188,12 @@ void run_multi_even_benchmark(benchmark::State& state, } template -void run_range_benchmark(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream, - size_t bins) +void run_range_benchmark(benchmark_utils::state&& state, size_t bins) { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + // Calculate the number of elements size_t size = bytes / sizeof(T); @@ -313,98 +212,48 @@ void run_range_benchmark(benchmark::State& state, levels[i] = static_cast(i); } - T* d_input; - level_type* d_levels; - counter_type* d_histogram; - HIP_CHECK(hipMalloc(&d_input, size * sizeof(T))); - HIP_CHECK(hipMalloc(&d_levels, (bins + 1) * sizeof(level_type))); - HIP_CHECK(hipMalloc(&d_histogram, size * sizeof(counter_type))); - HIP_CHECK(hipMemcpy(d_input, input.data(), size * sizeof(T), hipMemcpyHostToDevice)); - HIP_CHECK( - hipMemcpy(d_levels, levels.data(), (bins + 1) * sizeof(level_type), hipMemcpyHostToDevice)); - - void* d_temporary_storage = nullptr; + common::device_ptr d_input(input); + common::device_ptr d_levels(levels); + common::device_ptr d_histogram(size); + size_t temporary_storage_bytes = 0; - HIP_CHECK(rocprim::histogram_range(d_temporary_storage, + HIP_CHECK(rocprim::histogram_range(nullptr, temporary_storage_bytes, - d_input, + d_input.get(), size, - d_histogram, + d_histogram.get(), bins + 1, - d_levels, + d_levels.get(), stream, false)); - HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); + common::device_ptr d_temporary_storage(temporary_storage_bytes); HIP_CHECK(hipDeviceSynchronize()); - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - HIP_CHECK(rocprim::histogram_range(d_temporary_storage, - temporary_storage_bytes, - d_input, - size, - d_histogram, - bins + 1, - d_levels, - stream, - false)); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) + state.run( + [&] { - HIP_CHECK(rocprim::histogram_range(d_temporary_storage, + HIP_CHECK(rocprim::histogram_range(d_temporary_storage.get(), temporary_storage_bytes, - d_input, + d_input.get(), size, - d_histogram, + d_histogram.get(), bins + 1, - d_levels, + d_levels.get(), stream, false)); - } + }); - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(T)); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_temporary_storage)); - HIP_CHECK(hipFree(d_input)); - HIP_CHECK(hipFree(d_levels)); - HIP_CHECK(hipFree(d_histogram)); + state.set_throughput(size, sizeof(T)); } template -void run_multi_range_benchmark(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream, - size_t bins) +void run_multi_range_benchmark(benchmark_utils::state&& state, size_t bins) { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + // Calculate the number of elements size_t size = bytes / sizeof(T); @@ -432,17 +281,15 @@ void run_multi_range_benchmark(benchmark::State& state, random_range.second, seed.get_0()); - T* d_input; + common::device_ptr d_input(input); level_type* d_levels[ActiveChannels]; - counter_type* d_histogram[ActiveChannels]; - HIP_CHECK(hipMalloc(&d_input, size * Channels * sizeof(T))); + counter_type* d_histogram[ActiveChannels]; for(unsigned int channel = 0; channel < ActiveChannels; ++channel) { HIP_CHECK(hipMalloc(&d_levels[channel], num_levels_channel * sizeof(level_type))); HIP_CHECK(hipMalloc(&d_histogram[channel], size * sizeof(counter_type))); } - HIP_CHECK(hipMemcpy(d_input, input.data(), size * Channels * sizeof(T), hipMemcpyHostToDevice)); for(unsigned int channel = 0; channel < ActiveChannels; ++channel) { HIP_CHECK(hipMemcpy(d_levels[channel], @@ -451,11 +298,10 @@ void run_multi_range_benchmark(benchmark::State& state, hipMemcpyHostToDevice)); } - void* d_temporary_storage = nullptr; size_t temporary_storage_bytes = 0; - HIP_CHECK((rocprim::multi_histogram_range(d_temporary_storage, + HIP_CHECK((rocprim::multi_histogram_range(nullptr, temporary_storage_bytes, - d_input, + d_input.get(), size, d_histogram, num_levels, @@ -463,66 +309,26 @@ void run_multi_range_benchmark(benchmark::State& state, stream, false))); - HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); + common::device_ptr d_temporary_storage(temporary_storage_bytes); HIP_CHECK(hipDeviceSynchronize()); - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - HIP_CHECK((rocprim::multi_histogram_range(d_temporary_storage, - temporary_storage_bytes, - d_input, - size, - d_histogram, - num_levels, - d_levels, - stream, - false))); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) + state.run( + [&] { HIP_CHECK( - (rocprim::multi_histogram_range(d_temporary_storage, + (rocprim::multi_histogram_range(d_temporary_storage.get(), temporary_storage_bytes, - d_input, + d_input.get(), size, d_histogram, num_levels, d_levels, stream, false))); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); + }); - state.SetBytesProcessed(state.iterations() * batch_size * size * Channels * sizeof(T)); - state.SetItemsProcessed(state.iterations() * batch_size * size * Channels); + state.set_throughput(size * Channels, sizeof(T)); - HIP_CHECK(hipFree(d_temporary_storage)); - HIP_CHECK(hipFree(d_input)); for(unsigned int channel = 0; channel < ActiveChannels; ++channel) { HIP_CHECK(hipFree(d_levels[channel])); @@ -530,244 +336,145 @@ void run_multi_range_benchmark(benchmark::State& state, } } -#define CREATE_EVEN_BENCHMARK(VECTOR, T, BINS, SCALE) \ - VECTOR.push_back(benchmark::RegisterBenchmark( \ +#define CREATE_EVEN_BENCHMARK(T, BINS, SCALE) \ + executor.queue_fn( \ bench_naming::format_name("{lvl:device,algo:histogram_even,value_type:" #T ",entropy:" \ + std::to_string(get_entropy_percents(entropy_reduction)) \ + ",bins:" + std::to_string(BINS) + ",cfg:default_config}") \ .c_str(), \ - [=](benchmark::State& state) \ - { run_even_benchmark(state, bytes, seed, stream, BINS, SCALE, entropy_reduction); })); - -#define BENCHMARK_EVEN_TYPE(VECTOR, T, S) \ - CREATE_EVEN_BENCHMARK(VECTOR, T, 10, S); \ - CREATE_EVEN_BENCHMARK(VECTOR, T, 100, S); \ - CREATE_EVEN_BENCHMARK(VECTOR, T, 1000, S); \ - CREATE_EVEN_BENCHMARK(VECTOR, T, 10000, S); - -void add_even_benchmarks(std::vector& benchmarks, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) -{ - for(int entropy_reduction : entropy_reductions) - { - BENCHMARK_EVEN_TYPE(benchmarks, long long, 12345); - BENCHMARK_EVEN_TYPE(benchmarks, int, 1234); - BENCHMARK_EVEN_TYPE(benchmarks, short, 5); - CREATE_EVEN_BENCHMARK(benchmarks, unsigned char, 16, 16); - CREATE_EVEN_BENCHMARK(benchmarks, unsigned char, 256, 1); - BENCHMARK_EVEN_TYPE(benchmarks, double, 1234); - BENCHMARK_EVEN_TYPE(benchmarks, float, 1234); - BENCHMARK_EVEN_TYPE(benchmarks, rocprim::half, 5); - CREATE_EVEN_BENCHMARK(benchmarks, rocprim::int128_t, 16, 16); - CREATE_EVEN_BENCHMARK(benchmarks, rocprim::int128_t, 256, 1); - CREATE_EVEN_BENCHMARK(benchmarks, rocprim::uint128_t, 16, 16); - CREATE_EVEN_BENCHMARK(benchmarks, rocprim::uint128_t, 256, 1); - }; -} - -#define CREATE_MULTI_EVEN_BENCHMARK(CHANNELS, ACTIVE_CHANNELS, T, BINS, SCALE) \ - benchmark::RegisterBenchmark( \ - bench_naming::format_name("{lvl:device,algo:multi_histogram_even,value_type:" #T \ - ",channels:" #CHANNELS ",active_channels:" #ACTIVE_CHANNELS \ - ",entropy:" \ - + std::to_string(get_entropy_percents(entropy_reduction)) \ - + ",bins:" + std::to_string(BINS) + ",cfg:default_config}") \ - .c_str(), \ - [=](benchmark::State& state) \ - { \ - run_multi_even_benchmark(state, \ - bytes, \ - seed, \ - stream, \ - BINS, \ - SCALE, \ - entropy_reduction); \ - }) - -// clang-format off -#define BENCHMARK_MULTI_EVEN_TYPE(C, A, T, S) \ - CREATE_MULTI_EVEN_BENCHMARK(C, A, T, 10, S), \ - CREATE_MULTI_EVEN_BENCHMARK(C, A, T, 100, S), \ - CREATE_MULTI_EVEN_BENCHMARK(C, A, T, 1000, S), \ + [=](benchmark_utils::state&& state) \ + { \ + run_even_benchmark(std::forward(state), \ + BINS, \ + SCALE, \ + entropy_reduction); \ + }); + +#define BENCHMARK_EVEN_TYPE(T, S) \ + CREATE_EVEN_BENCHMARK(T, 10, S) \ + CREATE_EVEN_BENCHMARK(T, 100, S) \ + CREATE_EVEN_BENCHMARK(T, 1000, S) \ + CREATE_EVEN_BENCHMARK(T, 10000, S) + +#define CREATE_MULTI_EVEN_BENCHMARK(CHANNELS, ACTIVE_CHANNELS, T, BINS, SCALE) \ + executor.queue_fn(bench_naming::format_name( \ + "{lvl:device,algo:multi_histogram_even,value_type:" #T \ + ",channels:" #CHANNELS ",active_channels:" #ACTIVE_CHANNELS ",entropy:" \ + + std::to_string(get_entropy_percents(entropy_reduction)) \ + + ",bins:" + std::to_string(BINS) + ",cfg:default_config}") \ + .c_str(), \ + [=](benchmark_utils::state&& state) \ + { \ + run_multi_even_benchmark( \ + std::forward(state), \ + BINS, \ + SCALE, \ + entropy_reduction); \ + }); + +#define BENCHMARK_MULTI_EVEN_TYPE(C, A, T, S) \ + CREATE_MULTI_EVEN_BENCHMARK(C, A, T, 10, S) \ + CREATE_MULTI_EVEN_BENCHMARK(C, A, T, 100, S) \ + CREATE_MULTI_EVEN_BENCHMARK(C, A, T, 1000, S) \ CREATE_MULTI_EVEN_BENCHMARK(C, A, T, 10000, S) -// clang-format on - -void add_multi_even_benchmarks(std::vector& benchmarks, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) -{ - for(int entropy_reduction : entropy_reductions) - { - std::vector bs - = {BENCHMARK_MULTI_EVEN_TYPE(4, 4, int, 1234), - BENCHMARK_MULTI_EVEN_TYPE(4, 3, short, 5), - CREATE_MULTI_EVEN_BENCHMARK(4, 3, unsigned char, 16, 16), - CREATE_MULTI_EVEN_BENCHMARK(4, 3, unsigned char, 256, 1), - BENCHMARK_MULTI_EVEN_TYPE(3, 3, float, 1234), - CREATE_MULTI_EVEN_BENCHMARK(4, 3, rocprim::int128_t, 16, 16), - CREATE_MULTI_EVEN_BENCHMARK(4, 3, rocprim::int128_t, 256, 1), - CREATE_MULTI_EVEN_BENCHMARK(4, 3, rocprim::uint128_t, 16, 16), - CREATE_MULTI_EVEN_BENCHMARK(4, 3, rocprim::uint128_t, 256, 1)}; - benchmarks.insert(benchmarks.end(), bs.begin(), bs.end()); - }; -} #define CREATE_RANGE_BENCHMARK(T, BINS) \ - benchmark::RegisterBenchmark( \ + executor.queue_fn( \ bench_naming::format_name("{lvl:device,algo:histogram_range,value_type:" #T ",bins:" \ + std::to_string(BINS) + ",cfg:default_config}") \ .c_str(), \ - [=](benchmark::State& state) \ - { run_range_benchmark(state, bytes, seed, stream, BINS); }) - -// clang-format off -#define BENCHMARK_RANGE_TYPE(T) \ - CREATE_RANGE_BENCHMARK(T, 10), \ - CREATE_RANGE_BENCHMARK(T, 100), \ - CREATE_RANGE_BENCHMARK(T, 1000), \ - CREATE_RANGE_BENCHMARK(T, 10000) -// clang-format on + [=](benchmark_utils::state&& state) \ + { run_range_benchmark(std::forward(state), BINS); }); -void add_range_benchmarks(std::vector& benchmarks, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) -{ - std::vector bs - = {BENCHMARK_RANGE_TYPE(long long), - BENCHMARK_RANGE_TYPE(int), - BENCHMARK_RANGE_TYPE(short), - CREATE_RANGE_BENCHMARK(unsigned char, 16), - CREATE_RANGE_BENCHMARK(unsigned char, 256), - BENCHMARK_RANGE_TYPE(double), - BENCHMARK_RANGE_TYPE(float), - BENCHMARK_RANGE_TYPE(rocprim::half), - CREATE_RANGE_BENCHMARK(rocprim::int128_t, 16), - CREATE_RANGE_BENCHMARK(rocprim::int128_t, 256), - CREATE_RANGE_BENCHMARK(rocprim::uint128_t, 16), - CREATE_RANGE_BENCHMARK(rocprim::uint128_t, 256)}; - benchmarks.insert(benchmarks.end(), bs.begin(), bs.end()); -} +#define BENCHMARK_RANGE_TYPE(T) \ + CREATE_RANGE_BENCHMARK(T, 10) \ + CREATE_RANGE_BENCHMARK(T, 100) \ + CREATE_RANGE_BENCHMARK(T, 1000) \ + CREATE_RANGE_BENCHMARK(T, 10000) -#define CREATE_MULTI_RANGE_BENCHMARK(CHANNELS, ACTIVE_CHANNELS, T, BINS) \ - benchmark::RegisterBenchmark( \ - bench_naming::format_name("{lvl:device,algo:multi_histogram_range,value_type:" #T \ - ",channels:" #CHANNELS ",active_channels:" #ACTIVE_CHANNELS \ - ",bins:" \ - + std::to_string(BINS) + ",cfg:default_config}") \ - .c_str(), \ - [=](benchmark::State& state) { \ - run_multi_range_benchmark(state, \ - bytes, \ - seed, \ - stream, \ - BINS); \ - }) - -// clang-format off -#define BENCHMARK_MULTI_RANGE_TYPE(C, A, T) \ - CREATE_MULTI_RANGE_BENCHMARK(C, A, T, 10), \ - CREATE_MULTI_RANGE_BENCHMARK(C, A, T, 100), \ - CREATE_MULTI_RANGE_BENCHMARK(C, A, T, 1000), \ +#define CREATE_MULTI_RANGE_BENCHMARK(CHANNELS, ACTIVE_CHANNELS, T, BINS) \ + executor.queue_fn(bench_naming::format_name( \ + "{lvl:device,algo:multi_histogram_range,value_type:" #T \ + ",channels:" #CHANNELS ",active_channels:" #ACTIVE_CHANNELS ",bins:" \ + + std::to_string(BINS) + ",cfg:default_config}") \ + .c_str(), \ + [=](benchmark_utils::state&& state) \ + { \ + run_multi_range_benchmark( \ + std::forward(state), \ + BINS); \ + }); + +#define BENCHMARK_MULTI_RANGE_TYPE(C, A, T) \ + CREATE_MULTI_RANGE_BENCHMARK(C, A, T, 10) \ + CREATE_MULTI_RANGE_BENCHMARK(C, A, T, 100) \ + CREATE_MULTI_RANGE_BENCHMARK(C, A, T, 1000) \ CREATE_MULTI_RANGE_BENCHMARK(C, A, T, 10000) -// clang-format on - -void add_multi_range_benchmarks(std::vector& benchmarks, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) -{ - std::vector bs - = {BENCHMARK_MULTI_RANGE_TYPE(4, 4, int), - BENCHMARK_MULTI_RANGE_TYPE(4, 3, short), - CREATE_MULTI_RANGE_BENCHMARK(4, 3, unsigned char, 16), - CREATE_MULTI_RANGE_BENCHMARK(4, 3, unsigned char, 256), - BENCHMARK_MULTI_RANGE_TYPE(3, 3, float), - BENCHMARK_MULTI_RANGE_TYPE(2, 2, double), - CREATE_MULTI_RANGE_BENCHMARK(4, 3, rocprim::int128_t, 16), - CREATE_MULTI_RANGE_BENCHMARK(4, 3, rocprim::int128_t, 256), - CREATE_MULTI_RANGE_BENCHMARK(4, 3, rocprim::uint128_t, 16), - CREATE_MULTI_RANGE_BENCHMARK(4, 3, rocprim::uint128_t, 256)}; - benchmarks.insert(benchmarks.end(), bs.begin(), bs.end()); -} int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); -#ifdef BENCHMARK_CONFIG_TUNING - // optionally run an evenly split subset of benchmarks, when making multiple program invocations - parser.set_optional("parallel_instance", - "parallel_instance", - 0, - "parallel instance index"); - parser.set_optional("parallel_instances", - "parallel_instances", - 1, - "total parallel instances"); -#endif - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); - - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks; -#ifdef BENCHMARK_CONFIG_TUNING - const int parallel_instance = parser.get("parallel_instance"); - const int parallel_instances = parser.get("parallel_instances"); - config_autotune_register::register_benchmark_subset(benchmarks, - parallel_instance, - parallel_instances, - bytes, - seed, - stream); -#else // BENCHMARK_CONFIG_TUNING - add_even_benchmarks(benchmarks, bytes, seed, stream); - add_multi_even_benchmarks(benchmarks, bytes, seed, stream); - add_range_benchmarks(benchmarks, bytes, seed, stream); - add_multi_range_benchmarks(benchmarks, bytes, seed, stream); -#endif // BENCHMARK_CONFIG_TUNING - - // Use manual timing - for(auto& b : benchmarks) + benchmark_utils::executor executor(argc, argv, 128 * benchmark_utils::MiB, 10, 5); + +#ifndef BENCHMARK_CONFIG_TUNING + const int entropy_reductions[] = {0, 2, 4, 6}; + + // Even benchmarks + for(int entropy_reduction : entropy_reductions) { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); + BENCHMARK_EVEN_TYPE(long long, 12345) + BENCHMARK_EVEN_TYPE(int, 1234) + BENCHMARK_EVEN_TYPE(short, 5) + CREATE_EVEN_BENCHMARK(unsigned char, 16, 16) + CREATE_EVEN_BENCHMARK(unsigned char, 256, 1) + BENCHMARK_EVEN_TYPE(double, 1234) + BENCHMARK_EVEN_TYPE(float, 1234) + BENCHMARK_EVEN_TYPE(rocprim::half, 5) + CREATE_EVEN_BENCHMARK(rocprim::int128_t, 16, 16) + CREATE_EVEN_BENCHMARK(rocprim::int128_t, 256, 1) + CREATE_EVEN_BENCHMARK(rocprim::uint128_t, 16, 16) + CREATE_EVEN_BENCHMARK(rocprim::uint128_t, 256, 1) } - // Force number of iterations - if(trials > 0) + // Multi-even benchmarks + for(int entropy_reduction : entropy_reductions) { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } + BENCHMARK_MULTI_EVEN_TYPE(4, 4, int, 1234) + BENCHMARK_MULTI_EVEN_TYPE(4, 3, short, 5) + CREATE_MULTI_EVEN_BENCHMARK(4, 3, unsigned char, 16, 16) + CREATE_MULTI_EVEN_BENCHMARK(4, 3, unsigned char, 256, 1) + BENCHMARK_MULTI_EVEN_TYPE(3, 3, float, 1234) + CREATE_MULTI_EVEN_BENCHMARK(4, 3, rocprim::int128_t, 16, 16) + CREATE_MULTI_EVEN_BENCHMARK(4, 3, rocprim::int128_t, 256, 1) + CREATE_MULTI_EVEN_BENCHMARK(4, 3, rocprim::uint128_t, 16, 16) + CREATE_MULTI_EVEN_BENCHMARK(4, 3, rocprim::uint128_t, 256, 1) } - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + // Range benchmarks + BENCHMARK_RANGE_TYPE(long long) + BENCHMARK_RANGE_TYPE(int) + BENCHMARK_RANGE_TYPE(short) + CREATE_RANGE_BENCHMARK(unsigned char, 16) + CREATE_RANGE_BENCHMARK(unsigned char, 256) + BENCHMARK_RANGE_TYPE(double) + BENCHMARK_RANGE_TYPE(float) + BENCHMARK_RANGE_TYPE(rocprim::half) + CREATE_RANGE_BENCHMARK(rocprim::int128_t, 16) + CREATE_RANGE_BENCHMARK(rocprim::int128_t, 256) + CREATE_RANGE_BENCHMARK(rocprim::uint128_t, 16) + CREATE_RANGE_BENCHMARK(rocprim::uint128_t, 256) + + // Multi-range benchmarks + BENCHMARK_MULTI_RANGE_TYPE(4, 4, int) + BENCHMARK_MULTI_RANGE_TYPE(4, 3, short) + CREATE_MULTI_RANGE_BENCHMARK(4, 3, unsigned char, 16) + CREATE_MULTI_RANGE_BENCHMARK(4, 3, unsigned char, 256) + BENCHMARK_MULTI_RANGE_TYPE(3, 3, float) + BENCHMARK_MULTI_RANGE_TYPE(2, 2, double) + CREATE_MULTI_RANGE_BENCHMARK(4, 3, rocprim::int128_t, 16) + CREATE_MULTI_RANGE_BENCHMARK(4, 3, rocprim::int128_t, 256) + CREATE_MULTI_RANGE_BENCHMARK(4, 3, rocprim::uint128_t, 16) + CREATE_MULTI_RANGE_BENCHMARK(4, 3, rocprim::uint128_t, 256) +#endif + + executor.run(); } diff --git a/benchmark/benchmark_device_histogram.parallel.cpp.in b/benchmark/benchmark_device_histogram.parallel.cpp.in index bf45cd3c1..975d36cd8 100644 --- a/benchmark/benchmark_device_histogram.parallel.cpp.in +++ b/benchmark/benchmark_device_histogram.parallel.cpp.in @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2023-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -28,6 +28,6 @@ #include namespace { - auto benchmarks = config_autotune_register::create_bulk( + auto unused = benchmark_utils::executor::queue_autotune( device_histogram_benchmark_generator<@DataType@, @BlockSize@>::create); } diff --git a/benchmark/benchmark_device_histogram.parallel.hpp b/benchmark/benchmark_device_histogram.parallel.hpp index f0151d6c4..0d034b729 100644 --- a/benchmark/benchmark_device_histogram.parallel.hpp +++ b/benchmark/benchmark_device_histogram.parallel.hpp @@ -25,6 +25,8 @@ #include "benchmark_utils.hpp" +#include "../common/utils_device_ptr.hpp" + // Google Benchmark #include @@ -40,8 +42,10 @@ #include #include #include +#include #include #include +#include #include template @@ -81,6 +85,7 @@ std::vector generate(size_t size, int entropy_reduction, int lower_level, int // Cache for input data when multiple cases must be benchmarked with various configurations and // same inputs can be used for consecutive benchmarks. // It must be used as a singleton. +template class input_cache { public: @@ -91,23 +96,17 @@ class input_cache void clear() { - for(auto& i : cache) - { - HIP_CHECK(hipFree(i.second)); - } total_cache_size = 0; cache.clear(); } - // The function returns an exisitng buffer if main_key matches and there is additional_key - // in the cache or generates a new buffer using gen(). + // The function returns an existing buffer if main_key matches and there is additional_key + // in the cache, or generates a new buffer using gen(). // If main_key does not match, it frees all device buffers and resets the cache. - template - T* get_or_generate(const std::string& main_key, - const std::string& additional_key, - size_t size, - F gen) + template + T* get_or_generate(const std::string& main_key, const std::string& additional_key, F gen) { + // Experimentally determined maximum size, before the GPU runs out of memory. static constexpr short max_default_bytes_count = 176; if(this->main_key != main_key) { @@ -119,27 +118,30 @@ class input_cache auto result = cache.find(additional_key); if(result != cache.end()) { - return reinterpret_cast(result->second); + return reinterpret_cast(result->second.get()); } // Generate a new buffer std::vector data = gen(); - T* d_buffer; + common::device_ptr d_buffer; if(total_cache_size >= max_default_bytes_count) { + // the memory space of the value of last key-value pair is held by d_buffer + // and the pair is erased from the cache map auto iter = cache.end(); --iter; - d_buffer = reinterpret_cast(iter->second); + d_buffer = std::move(iter->second); cache.erase(iter); } else { - HIP_CHECK(hipMalloc(&d_buffer, size * sizeof(T))); + // it will generate a new memory space to store in cache + // so records the new size in advance total_cache_size += sizeof(T); } - HIP_CHECK(hipMemcpy(d_buffer, data.data(), size * sizeof(T), hipMemcpyHostToDevice)); - cache[additional_key] = d_buffer; - return d_buffer; + d_buffer.store(data); + cache[additional_key] = std::move(d_buffer); + return cache[additional_key].get(); } static input_cache& instance() @@ -150,7 +152,7 @@ class input_cache private: std::string main_key; - std::map cache; + std::map> cache; short total_cache_size = 0; }; @@ -171,8 +173,11 @@ inline std::string config_name() return "default_config"; } -template -struct device_histogram_benchmark : public config_autotune_interface +template +struct device_histogram_benchmark : public benchmark_utils::autotune_interface { std::vector cases; @@ -187,14 +192,11 @@ struct device_histogram_benchmark : public config_autotune_interface + ",cfg:" + config_name() + "}"); } - static constexpr unsigned int batch_size = 3; - static constexpr unsigned int warmup_size = 5; - - void run(benchmark::State& state, - size_t full_size, - const managed_seed&, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + using counter_type = unsigned int; using level_type = typename std:: conditional_t::value && sizeof(T) < sizeof(int), int, T>; @@ -206,20 +208,18 @@ struct device_histogram_benchmark : public config_autotune_interface level_type lower_level[ActiveChannels]{}; level_type upper_level[ActiveChannels]{}; unsigned int num_levels[ActiveChannels]{}; - T* get_d_input(size_t full_size) + T* get_d_input(size_t bytes) { - return input_cache::instance().get_or_generate( + return input_cache::instance().get_or_generate( std::string(Traits::name()), std::to_string(bins) + "_" + std::to_string(entropy_reduction), - full_size, - [&]() { return generate(full_size, entropy_reduction, 0, bins); }); + [&]() { return generate(bytes, entropy_reduction, 0, bins); }); }; }; - const std::size_t size = full_size / Channels; + const std::size_t size = bytes / Channels; size_t temporary_storage_bytes = 0; - void* d_temporary_storage = nullptr; counter_type* d_histogram[ActiveChannels]; unsigned int max_bins = 0; @@ -245,9 +245,9 @@ struct device_histogram_benchmark : public config_autotune_interface size_t current_temporary_storage_bytes = 0; HIP_CHECK((rocprim::multi_histogram_even( - d_temporary_storage, + nullptr, current_temporary_storage_bytes, - data.get_d_input(full_size), + data.get_d_input(bytes), size, d_histogram, data.num_levels, @@ -262,52 +262,24 @@ struct device_histogram_benchmark : public config_autotune_interface } } - HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); + common::device_ptr d_temporary_storage(temporary_storage_bytes); for(unsigned int channel = 0; channel < ActiveChannels; ++channel) { HIP_CHECK(hipMalloc(&d_histogram[channel], max_bins * sizeof(counter_type))); } HIP_CHECK(hipDeviceSynchronize()); - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - for(auto& data : cases_data) - { - HIP_CHECK((rocprim::multi_histogram_even( - d_temporary_storage, - temporary_storage_bytes, - data.get_d_input(full_size), - size, - d_histogram, - data.num_levels, - data.lower_level, - data.upper_level, - stream, - false))); - } - } - HIP_CHECK(hipDeviceSynchronize()); + size_t total_size = 0; - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) + for(auto& data : cases_data) { - float elapsed_mseconds = 0; - for(auto& data : cases_data) - { - T* d_input = data.get_d_input(full_size); - float partial_elapsed_mseconds; - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); + T* d_input = data.get_d_input(bytes); - for(size_t i = 0; i < batch_size; ++i) + state.run( + [&] { HIP_CHECK((rocprim::multi_histogram_even( - d_temporary_storage, + d_temporary_storage.get(), temporary_storage_bytes, d_input, size, @@ -317,28 +289,13 @@ struct device_histogram_benchmark : public config_autotune_interface data.upper_level, stream, false))); - } + }); - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - HIP_CHECK(hipEventElapsedTime(&partial_elapsed_mseconds, start, stop)); - elapsed_mseconds += partial_elapsed_mseconds; - } - state.SetIterationTime(elapsed_mseconds / 1000); + total_size += size * Channels; } - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * cases_data.size() * batch_size * size - * Channels * sizeof(T)); - state.SetItemsProcessed(state.iterations() * cases_data.size() * batch_size * size - * Channels); + state.set_throughput(total_size, sizeof(T)); - HIP_CHECK(hipFree(d_temporary_storage)); for(unsigned int channel = 0; channel < ActiveChannels; ++channel) { HIP_CHECK(hipFree(d_histogram[channel])); @@ -369,8 +326,8 @@ struct device_histogram_benchmark_generator template - auto create(std::vector>& storage, - const std::vector& cases) -> + auto create(std::vector>& storage, + const std::vector& cases) -> typename std::enable_if<(items_per_thread * Channels <= max_items_per_thread), void>::type { @@ -383,14 +340,16 @@ struct device_histogram_benchmark_generator template - auto create(std::vector>& /*storage*/, - const std::vector& /*cases*/) -> + auto create( + std::vector>& /*storage*/, + const std::vector& /*cases*/) -> typename std::enable_if::type {} - void operator()(std::vector>& storage, - const std::vector& cases) + void operator()( + std::vector>& storage, + const std::vector& cases) { // Tune histograms for single-channel data (histogram_even) create<1, 1>(storage, cases); @@ -402,8 +361,8 @@ struct device_histogram_benchmark_generator } }; - void operator()(std::vector>& storage, - const std::vector& cases) + void operator()(std::vector>& storage, + const std::vector& cases) { static_for_each>& storage) + static void create(std::vector>& storage) { // Benchmark multiple cases (with various sample distributions) and use sum of all cases // as a measurement for autotuning diff --git a/benchmark/benchmark_device_memory.cpp b/benchmark/benchmark_device_memory.cpp index a3480eee1..05d1387ea 100644 --- a/benchmark/benchmark_device_memory.cpp +++ b/benchmark/benchmark_device_memory.cpp @@ -21,13 +21,10 @@ // SOFTWARE. #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" #include "../common/utils_data_generation.hpp" +#include "../common/utils_device_ptr.hpp" -// Google Benchmark -#include // rocPRIM #include #include @@ -334,52 +331,28 @@ template -void run_benchmark(benchmark::State& state, - size_t size, - const managed_seed& seed, - const hipStream_t stream) +void run_benchmark(benchmark_utils::state&& state) { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + + const size_t size = bytes / sizeof(T); + const size_t grid_size = size / (BlockSize * ItemsPerThread); std::vector input = get_random_data(size, common::generate_limits::min(), common::generate_limits::max(), seed.get_0()); - T* d_input; - T* d_output; - HIP_CHECK(hipMalloc(reinterpret_cast(&d_input), size * sizeof(T))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_output), size * sizeof(T))); - HIP_CHECK(hipMemcpy(d_input, input.data(), size * sizeof(T), hipMemcpyHostToDevice)); + common::device_ptr d_input(input); + common::device_ptr d_output(size); HIP_CHECK(hipDeviceSynchronize()); operation selected_operation; - // Warm-up - for(size_t i = 0; i < 10; ++i) - { - hipLaunchKernelGGL(HIP_KERNEL_NAME(operation_kernel), - dim3(grid_size), - dim3(BlockSize), - 0, - stream, - d_input, - d_output, - selected_operation); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - const unsigned int batch_size = 10; - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) + state.run( + [&] { hipLaunchKernelGGL( HIP_KERNEL_NAME(operation_kernel), @@ -387,2818 +360,753 @@ void run_benchmark(benchmark::State& state, dim3(BlockSize), 0, stream, - d_input, - d_output, + d_input.get(), + d_output.get(), selected_operation); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); + }); - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(T)); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_input)); - HIP_CHECK(hipFree(d_output)); + state.set_throughput(size, sizeof(T)); } template -void run_benchmark_memcpy(benchmark::State& state, - size_t size, - const managed_seed&, - const hipStream_t stream) +void run_benchmark_memcpy(benchmark_utils::state&& state) { + const auto& bytes = state.bytes; + + const size_t size = bytes / sizeof(T); + // Allocate device buffers // Note: since this benchmark only tests performance by memcpying between device buffers, // we don't really need to transfer data into these from the host - whatever happens // to be in device memory will do. - T* d_input; - T* d_output; - HIP_CHECK(hipMalloc(reinterpret_cast(&d_input), size * sizeof(T))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_output), size * sizeof(T))); - // Warm-up - for(size_t i = 0; i < 10; ++i) - { - HIP_CHECK(hipMemcpy(d_output, d_input, size * sizeof(T), hipMemcpyDeviceToDevice)); - } - HIP_CHECK(hipDeviceSynchronize()); + common::device_ptr d_input(size); + common::device_ptr d_output(size); - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - const unsigned int batch_size = 10; - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) + state.run( + [&] { - HIP_CHECK(hipMemcpy(d_output, d_input, size * sizeof(T), hipMemcpyDeviceToDevice)); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); + HIP_CHECK(hipMemcpy(d_output.get(), + d_input.get(), + size * sizeof(T), + hipMemcpyDeviceToDevice)); + }); - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(T)); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_input)); - HIP_CHECK(hipFree(d_output)); + state.set_throughput(size, sizeof(T)); } -#define CREATE_BENCHMARK(METHOD, OPERATION, T, SIZE, BLOCK_SIZE, IPT) \ - benchmark::RegisterBenchmark( \ - bench_naming::format_name("{lvl:device,algo:memory,subalgo:" #METHOD \ - ",operation:" #OPERATION ",key_type:" #T ",size:" #SIZE \ - ",cfg:{bs:" #BLOCK_SIZE ",ipt:" #IPT "}}") \ - .c_str(), \ - run_benchmark, \ - SIZE, \ - seed, \ - stream) - -#define CREATE_BENCHMARK_MEMCPY(T, SIZE) \ - benchmark::RegisterBenchmark( \ +#define CREATE_BENCHMARK(METHOD, OPERATION, T, BLOCK_SIZE, IPT) \ + executor.queue_fn(bench_naming::format_name("{lvl:device,algo:memory,subalgo:" #METHOD \ + ",operation:" #OPERATION ",key_type:" #T \ + ",cfg:{bs:" #BLOCK_SIZE ",ipt:" #IPT "}}") \ + .c_str(), \ + run_benchmark); + +#define CREATE_BENCHMARK_MEMCPY(T) \ + executor.queue_fn( \ bench_naming::format_name("{lvl:device,algo:memory,subalgo:copy,key_type:" #T \ - ",size:" #SIZE ",cfg:default_config}") \ + ",cfg:default_config}") \ .c_str(), \ - run_benchmark_memcpy, \ - SIZE, \ - seed, \ - stream) - -template -constexpr unsigned int megabytes(unsigned int size) -{ - return (size * (1024 * 1024 / sizeof(T))); -} + run_benchmark_memcpy); int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); - - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks = { - // simple memory copy not running kernel - CREATE_BENCHMARK_MEMCPY(int, megabytes(128)), - - // simple memory copy - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - int, - megabytes(128), - 128, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - int, - megabytes(128), - 128, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - int, - megabytes(128), - 128, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - int, - megabytes(128), - 128, - 8), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - int, - megabytes(128), - 128, - 16), - - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - int, - megabytes(128), - 256, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - int, - megabytes(128), - 256, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - int, - megabytes(128), - 256, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - int, - megabytes(128), - 256, - 8), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - int, - megabytes(128), - 256, - 16), - - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - int, - megabytes(128), - 512, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - int, - megabytes(128), - 512, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - int, - megabytes(128), - 512, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - int, - megabytes(128), - 512, - 8), - - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - int, - megabytes(128), - 1024, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - int, - megabytes(128), - 1024, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - int, - megabytes(128), - 1024, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - int, - megabytes(128), - 1024, - 8), - - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - uint64_t, - megabytes(128), - 128, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - uint64_t, - megabytes(128), - 128, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - uint64_t, - megabytes(128), - 128, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - uint64_t, - megabytes(128), - 128, - 8), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - uint64_t, - megabytes(128), - 128, - 16), - - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - uint64_t, - megabytes(128), - 256, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - uint64_t, - megabytes(128), - 256, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - uint64_t, - megabytes(128), - 256, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - uint64_t, - megabytes(128), - 256, - 8), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - uint64_t, - megabytes(128), - 256, - 16), - - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - uint64_t, - megabytes(128), - 512, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - uint64_t, - megabytes(128), - 512, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - uint64_t, - megabytes(128), - 512, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - uint64_t, - megabytes(128), - 512, - 8), - - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - uint64_t, - megabytes(128), - 1024, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - uint64_t, - megabytes(128), - 1024, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - uint64_t, - megabytes(128), - 1024, - 4), - - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - rocprim::int128_t, - megabytes(128), - 128, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - rocprim::int128_t, - megabytes(128), - 128, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - rocprim::int128_t, - megabytes(128), - 128, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - rocprim::int128_t, - megabytes(128), - 128, - 8), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - rocprim::int128_t, - megabytes(128), - 128, - 16), - - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - rocprim::int128_t, - megabytes(128), - 256, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - rocprim::int128_t, - megabytes(128), - 256, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - rocprim::int128_t, - megabytes(128), - 256, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - rocprim::int128_t, - megabytes(128), - 256, - 8), - - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - rocprim::int128_t, - megabytes(128), - 512, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - rocprim::int128_t, - megabytes(128), - 512, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - rocprim::int128_t, - megabytes(128), - 512, - 4), - - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - rocprim::int128_t, - megabytes(128), - 1024, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - rocprim::int128_t, - megabytes(128), - 1024, - 2), - - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - rocprim::uint128_t, - megabytes(128), - 128, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - rocprim::uint128_t, - megabytes(128), - 128, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - rocprim::uint128_t, - megabytes(128), - 128, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - rocprim::uint128_t, - megabytes(128), - 128, - 8), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - rocprim::uint128_t, - megabytes(128), - 128, - 16), - - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - rocprim::uint128_t, - megabytes(128), - 256, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - rocprim::uint128_t, - megabytes(128), - 256, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - rocprim::uint128_t, - megabytes(128), - 256, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - rocprim::uint128_t, - megabytes(128), - 256, - 8), - - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - rocprim::uint128_t, - megabytes(128), - 512, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - rocprim::uint128_t, - megabytes(128), - 512, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - rocprim::uint128_t, - megabytes(128), - 512, - 4), - - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - rocprim::uint128_t, - megabytes(128), - 1024, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - no_operation, - rocprim::uint128_t, - megabytes(128), - 1024, - 2), - - // simple memory copy using vector type - CREATE_BENCHMARK(vectorized, no_operation, int, megabytes(128), 128, 1), - CREATE_BENCHMARK(vectorized, no_operation, int, megabytes(128), 128, 2), - CREATE_BENCHMARK(vectorized, no_operation, int, megabytes(128), 128, 4), - CREATE_BENCHMARK(vectorized, no_operation, int, megabytes(128), 128, 8), - CREATE_BENCHMARK(vectorized, no_operation, int, megabytes(128), 128, 16), - - CREATE_BENCHMARK(vectorized, no_operation, int, megabytes(128), 256, 1), - CREATE_BENCHMARK(vectorized, no_operation, int, megabytes(128), 256, 2), - CREATE_BENCHMARK(vectorized, no_operation, int, megabytes(128), 256, 4), - CREATE_BENCHMARK(vectorized, no_operation, int, megabytes(128), 256, 8), - CREATE_BENCHMARK(vectorized, no_operation, int, megabytes(128), 256, 16), - - CREATE_BENCHMARK(vectorized, no_operation, int, megabytes(128), 512, 1), - CREATE_BENCHMARK(vectorized, no_operation, int, megabytes(128), 512, 2), - CREATE_BENCHMARK(vectorized, no_operation, int, megabytes(128), 512, 4), - CREATE_BENCHMARK(vectorized, no_operation, int, megabytes(128), 512, 8), - - CREATE_BENCHMARK(vectorized, no_operation, int, megabytes(128), 1024, 1), - CREATE_BENCHMARK(vectorized, no_operation, int, megabytes(128), 1024, 2), - CREATE_BENCHMARK(vectorized, no_operation, int, megabytes(128), 1024, 4), - CREATE_BENCHMARK(vectorized, no_operation, int, megabytes(128), 1024, 8), - - CREATE_BENCHMARK(vectorized, no_operation, uint64_t, megabytes(128), 128, 1), - CREATE_BENCHMARK(vectorized, no_operation, uint64_t, megabytes(128), 128, 2), - CREATE_BENCHMARK(vectorized, no_operation, uint64_t, megabytes(128), 128, 4), - CREATE_BENCHMARK(vectorized, no_operation, uint64_t, megabytes(128), 128, 8), - CREATE_BENCHMARK(vectorized, no_operation, uint64_t, megabytes(128), 128, 16), - - CREATE_BENCHMARK(vectorized, no_operation, uint64_t, megabytes(128), 256, 1), - CREATE_BENCHMARK(vectorized, no_operation, uint64_t, megabytes(128), 256, 2), - CREATE_BENCHMARK(vectorized, no_operation, uint64_t, megabytes(128), 256, 4), - CREATE_BENCHMARK(vectorized, no_operation, uint64_t, megabytes(128), 256, 8), - CREATE_BENCHMARK(vectorized, no_operation, uint64_t, megabytes(128), 256, 16), - - CREATE_BENCHMARK(vectorized, no_operation, uint64_t, megabytes(128), 512, 1), - CREATE_BENCHMARK(vectorized, no_operation, uint64_t, megabytes(128), 512, 2), - CREATE_BENCHMARK(vectorized, no_operation, uint64_t, megabytes(128), 512, 4), - CREATE_BENCHMARK(vectorized, no_operation, uint64_t, megabytes(128), 512, 8), - - CREATE_BENCHMARK(vectorized, no_operation, uint64_t, megabytes(128), 1024, 1), - CREATE_BENCHMARK(vectorized, no_operation, uint64_t, megabytes(128), 1024, 2), - CREATE_BENCHMARK(vectorized, no_operation, uint64_t, megabytes(128), 1024, 4), - CREATE_BENCHMARK(vectorized, no_operation, uint64_t, megabytes(128), 1024, 8), - - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::int128_t, - megabytes(128), - 128, - 1), - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::int128_t, - megabytes(128), - 128, - 2), - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::int128_t, - megabytes(128), - 128, - 4), - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::int128_t, - megabytes(128), - 128, - 8), - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::int128_t, - megabytes(128), - 128, - 16), - - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::int128_t, - megabytes(128), - 256, - 1), - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::int128_t, - megabytes(128), - 256, - 2), - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::int128_t, - megabytes(128), - 256, - 4), - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::int128_t, - megabytes(128), - 256, - 8), - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::int128_t, - megabytes(128), - 256, - 16), - - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::int128_t, - megabytes(128), - 512, - 1), - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::int128_t, - megabytes(128), - 512, - 2), - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::int128_t, - megabytes(128), - 512, - 4), - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::int128_t, - megabytes(128), - 512, - 8), - - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::int128_t, - megabytes(128), - 1024, - 1), - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::int128_t, - megabytes(128), - 1024, - 2), - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::int128_t, - megabytes(128), - 1024, - 4), - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::int128_t, - megabytes(128), - 1024, - 8), - - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::uint128_t, - megabytes(128), - 128, - 1), - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::uint128_t, - megabytes(128), - 128, - 2), - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::uint128_t, - megabytes(128), - 128, - 4), - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::uint128_t, - megabytes(128), - 128, - 8), - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::uint128_t, - megabytes(128), - 128, - 16), - - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::uint128_t, - megabytes(128), - 256, - 1), - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::uint128_t, - megabytes(128), - 256, - 2), - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::uint128_t, - megabytes(128), - 256, - 4), - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::uint128_t, - megabytes(128), - 256, - 8), - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::uint128_t, - megabytes(128), - 256, - 16), - - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::uint128_t, - megabytes(128), - 512, - 1), - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::uint128_t, - megabytes(128), - 512, - 2), - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::uint128_t, - megabytes(128), - 512, - 4), - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::uint128_t, - megabytes(128), - 512, - 8), - - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::uint128_t, - megabytes(128), - 1024, - 1), - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::uint128_t, - megabytes(128), - 1024, - 2), - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::uint128_t, - megabytes(128), - 1024, - 4), - CREATE_BENCHMARK(vectorized, - no_operation, - rocprim::uint128_t, - megabytes(128), - 1024, - 8), - - // simple memory copy using striped - CREATE_BENCHMARK(striped, no_operation, int, megabytes(128), 128, 1), - CREATE_BENCHMARK(striped, no_operation, int, megabytes(128), 128, 2), - CREATE_BENCHMARK(striped, no_operation, int, megabytes(128), 128, 4), - CREATE_BENCHMARK(striped, no_operation, int, megabytes(128), 128, 8), - CREATE_BENCHMARK(striped, no_operation, int, megabytes(128), 128, 16), - - CREATE_BENCHMARK(striped, no_operation, int, megabytes(128), 256, 1), - CREATE_BENCHMARK(striped, no_operation, int, megabytes(128), 256, 2), - CREATE_BENCHMARK(striped, no_operation, int, megabytes(128), 256, 4), - CREATE_BENCHMARK(striped, no_operation, int, megabytes(128), 256, 8), - CREATE_BENCHMARK(striped, no_operation, int, megabytes(128), 256, 16), - - CREATE_BENCHMARK(striped, no_operation, int, megabytes(128), 512, 1), - CREATE_BENCHMARK(striped, no_operation, int, megabytes(128), 512, 2), - CREATE_BENCHMARK(striped, no_operation, int, megabytes(128), 512, 4), - CREATE_BENCHMARK(striped, no_operation, int, megabytes(128), 512, 8), - - CREATE_BENCHMARK(striped, no_operation, int, megabytes(128), 1024, 1), - CREATE_BENCHMARK(striped, no_operation, int, megabytes(128), 1024, 2), - CREATE_BENCHMARK(striped, no_operation, int, megabytes(128), 1024, 4), - CREATE_BENCHMARK(striped, no_operation, int, megabytes(128), 1024, 8), - - CREATE_BENCHMARK(striped, no_operation, uint64_t, megabytes(128), 128, 1), - CREATE_BENCHMARK(striped, no_operation, uint64_t, megabytes(128), 128, 2), - CREATE_BENCHMARK(striped, no_operation, uint64_t, megabytes(128), 128, 4), - CREATE_BENCHMARK(striped, no_operation, uint64_t, megabytes(128), 128, 8), - CREATE_BENCHMARK(striped, no_operation, uint64_t, megabytes(128), 128, 16), - - CREATE_BENCHMARK(striped, no_operation, uint64_t, megabytes(128), 256, 1), - CREATE_BENCHMARK(striped, no_operation, uint64_t, megabytes(128), 256, 2), - CREATE_BENCHMARK(striped, no_operation, uint64_t, megabytes(128), 256, 4), - CREATE_BENCHMARK(striped, no_operation, uint64_t, megabytes(128), 256, 8), - CREATE_BENCHMARK(striped, no_operation, uint64_t, megabytes(128), 256, 16), - - CREATE_BENCHMARK(striped, no_operation, uint64_t, megabytes(128), 512, 1), - CREATE_BENCHMARK(striped, no_operation, uint64_t, megabytes(128), 512, 2), - CREATE_BENCHMARK(striped, no_operation, uint64_t, megabytes(128), 512, 4), - CREATE_BENCHMARK(striped, no_operation, uint64_t, megabytes(128), 512, 8), - - CREATE_BENCHMARK(striped, no_operation, uint64_t, megabytes(128), 1024, 1), - CREATE_BENCHMARK(striped, no_operation, uint64_t, megabytes(128), 1024, 2), - CREATE_BENCHMARK(striped, no_operation, uint64_t, megabytes(128), 1024, 4), - CREATE_BENCHMARK(striped, no_operation, uint64_t, megabytes(128), 1024, 8), - - CREATE_BENCHMARK(striped, - no_operation, - rocprim::int128_t, - megabytes(128), - 128, - 1), - CREATE_BENCHMARK(striped, - no_operation, - rocprim::int128_t, - megabytes(128), - 128, - 2), - CREATE_BENCHMARK(striped, - no_operation, - rocprim::int128_t, - megabytes(128), - 128, - 4), - CREATE_BENCHMARK(striped, - no_operation, - rocprim::int128_t, - megabytes(128), - 128, - 8), - CREATE_BENCHMARK(striped, - no_operation, - rocprim::int128_t, - megabytes(128), - 128, - 16), - - CREATE_BENCHMARK(striped, - no_operation, - rocprim::int128_t, - megabytes(128), - 256, - 1), - CREATE_BENCHMARK(striped, - no_operation, - rocprim::int128_t, - megabytes(128), - 256, - 2), - CREATE_BENCHMARK(striped, - no_operation, - rocprim::int128_t, - megabytes(128), - 256, - 4), - CREATE_BENCHMARK(striped, - no_operation, - rocprim::int128_t, - megabytes(128), - 256, - 8), - CREATE_BENCHMARK(striped, - no_operation, - rocprim::int128_t, - megabytes(128), - 256, - 16), - - CREATE_BENCHMARK(striped, - no_operation, - rocprim::int128_t, - megabytes(128), - 512, - 1), - CREATE_BENCHMARK(striped, - no_operation, - rocprim::int128_t, - megabytes(128), - 512, - 2), - CREATE_BENCHMARK(striped, - no_operation, - rocprim::int128_t, - megabytes(128), - 512, - 4), - CREATE_BENCHMARK(striped, - no_operation, - rocprim::int128_t, - megabytes(128), - 512, - 8), - - CREATE_BENCHMARK(striped, - no_operation, - rocprim::int128_t, - megabytes(128), - 1024, - 1), - CREATE_BENCHMARK(striped, - no_operation, - rocprim::int128_t, - megabytes(128), - 1024, - 2), - CREATE_BENCHMARK(striped, - no_operation, - rocprim::int128_t, - megabytes(128), - 1024, - 4), - CREATE_BENCHMARK(striped, - no_operation, - rocprim::int128_t, - megabytes(128), - 1024, - 8), - - CREATE_BENCHMARK(striped, - no_operation, - rocprim::uint128_t, - megabytes(128), - 128, - 1), - CREATE_BENCHMARK(striped, - no_operation, - rocprim::uint128_t, - megabytes(128), - 128, - 2), - CREATE_BENCHMARK(striped, - no_operation, - rocprim::uint128_t, - megabytes(128), - 128, - 4), - CREATE_BENCHMARK(striped, - no_operation, - rocprim::uint128_t, - megabytes(128), - 128, - 8), - CREATE_BENCHMARK(striped, - no_operation, - rocprim::uint128_t, - megabytes(128), - 128, - 16), - - CREATE_BENCHMARK(striped, - no_operation, - rocprim::uint128_t, - megabytes(128), - 256, - 1), - CREATE_BENCHMARK(striped, - no_operation, - rocprim::uint128_t, - megabytes(128), - 256, - 2), - CREATE_BENCHMARK(striped, - no_operation, - rocprim::uint128_t, - megabytes(128), - 256, - 4), - CREATE_BENCHMARK(striped, - no_operation, - rocprim::uint128_t, - megabytes(128), - 256, - 8), - CREATE_BENCHMARK(striped, - no_operation, - rocprim::uint128_t, - megabytes(128), - 256, - 16), - - CREATE_BENCHMARK(striped, - no_operation, - rocprim::uint128_t, - megabytes(128), - 512, - 1), - CREATE_BENCHMARK(striped, - no_operation, - rocprim::uint128_t, - megabytes(128), - 512, - 2), - CREATE_BENCHMARK(striped, - no_operation, - rocprim::uint128_t, - megabytes(128), - 512, - 4), - CREATE_BENCHMARK(striped, - no_operation, - rocprim::uint128_t, - megabytes(128), - 512, - 8), - - CREATE_BENCHMARK(striped, - no_operation, - rocprim::uint128_t, - megabytes(128), - 1024, - 1), - CREATE_BENCHMARK(striped, - no_operation, - rocprim::uint128_t, - megabytes(128), - 1024, - 2), - CREATE_BENCHMARK(striped, - no_operation, - rocprim::uint128_t, - megabytes(128), - 1024, - 4), - CREATE_BENCHMARK(striped, - no_operation, - rocprim::uint128_t, - megabytes(128), - 1024, - 8), - - // block_scan - CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, megabytes(128), 128, 1), - CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, megabytes(128), 128, 2), - CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, megabytes(128), 128, 4), - CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, megabytes(128), 128, 8), - CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, megabytes(128), 128, 16), - CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, megabytes(128), 128, 32), - - CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, megabytes(128), 256, 1), - CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, megabytes(128), 256, 2), - CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, megabytes(128), 256, 4), - CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, megabytes(128), 256, 8), - CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, megabytes(128), 256, 16), - - CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, megabytes(128), 512, 1), - CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, megabytes(128), 512, 2), - CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, megabytes(128), 512, 4), - CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, megabytes(128), 512, 8), - - CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, megabytes(128), 1024, 1), - CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, megabytes(128), 1024, 2), - CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, megabytes(128), 1024, 4), - CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, megabytes(128), 1024, 8), - - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - float, - megabytes(128), - 256, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - float, - megabytes(128), - 256, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - float, - megabytes(128), - 256, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - float, - megabytes(128), - 256, - 8), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - float, - megabytes(128), - 256, - 16), - - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - float, - megabytes(128), - 512, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - float, - megabytes(128), - 512, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - float, - megabytes(128), - 512, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - float, - megabytes(128), - 512, - 8), - - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - float, - megabytes(128), - 1024, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - float, - megabytes(128), - 1024, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - float, - megabytes(128), - 1024, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - float, - megabytes(128), - 1024, - 8), - - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - double, - megabytes(128), - 128, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - double, - megabytes(128), - 128, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - double, - megabytes(128), - 128, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - double, - megabytes(128), - 128, - 8), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - double, - megabytes(128), - 128, - 16), - - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - double, - megabytes(128), - 256, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - double, - megabytes(128), - 256, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - double, - megabytes(128), - 256, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - double, - megabytes(128), - 256, - 8), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - double, - megabytes(128), - 256, - 16), - - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - double, - megabytes(128), - 512, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - double, - megabytes(128), - 512, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - double, - megabytes(128), - 512, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - double, - megabytes(128), - 512, - 8), - - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - double, - megabytes(128), - 1024, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - double, - megabytes(128), - 1024, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - double, - megabytes(128), - 1024, - 4), - - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - uint64_t, - megabytes(128), - 128, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - uint64_t, - megabytes(128), - 128, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - uint64_t, - megabytes(128), - 128, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - uint64_t, - megabytes(128), - 128, - 8), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - uint64_t, - megabytes(128), - 128, - 16), - - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - uint64_t, - megabytes(128), - 256, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - uint64_t, - megabytes(128), - 256, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - uint64_t, - megabytes(128), - 256, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - uint64_t, - megabytes(128), - 256, - 8), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - uint64_t, - megabytes(128), - 256, - 16), - - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - uint64_t, - megabytes(128), - 512, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - uint64_t, - megabytes(128), - 512, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - uint64_t, - megabytes(128), - 512, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - uint64_t, - megabytes(128), - 512, - 8), - - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - uint64_t, - megabytes(128), - 1024, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - uint64_t, - megabytes(128), - 1024, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - uint64_t, - megabytes(128), - 1024, - 4), - - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - rocprim::int128_t, - megabytes(128), - 128, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - rocprim::int128_t, - megabytes(128), - 128, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - rocprim::int128_t, - megabytes(128), - 128, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - rocprim::int128_t, - megabytes(128), - 128, - 8), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - rocprim::int128_t, - megabytes(128), - 128, - 16), - - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - rocprim::int128_t, - megabytes(128), - 256, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - rocprim::int128_t, - megabytes(128), - 256, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - rocprim::int128_t, - megabytes(128), - 256, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - rocprim::int128_t, - megabytes(128), - 256, - 8), - - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - rocprim::int128_t, - megabytes(128), - 512, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - rocprim::int128_t, - megabytes(128), - 512, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - rocprim::int128_t, - megabytes(128), - 512, - 4), - - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - rocprim::int128_t, - megabytes(128), - 1024, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - rocprim::int128_t, - megabytes(128), - 1024, - 2), - - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - rocprim::uint128_t, - megabytes(128), - 128, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - rocprim::uint128_t, - megabytes(128), - 128, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - rocprim::uint128_t, - megabytes(128), - 128, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - rocprim::uint128_t, - megabytes(128), - 128, - 8), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - rocprim::uint128_t, - megabytes(128), - 128, - 16), - - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - rocprim::uint128_t, - megabytes(128), - 256, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - rocprim::uint128_t, - megabytes(128), - 256, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - rocprim::uint128_t, - megabytes(128), - 256, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - rocprim::uint128_t, - megabytes(128), - 256, - 8), - - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - rocprim::uint128_t, - megabytes(128), - 512, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - rocprim::uint128_t, - megabytes(128), - 512, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - rocprim::uint128_t, - megabytes(128), - 512, - 4), - - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - rocprim::uint128_t, - megabytes(128), - 1024, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - block_scan, - rocprim::uint128_t, - megabytes(128), - 1024, - 2), - - // vectorized - block_scan - CREATE_BENCHMARK(vectorized, block_scan, int, megabytes(128), 128, 1), - CREATE_BENCHMARK(vectorized, block_scan, int, megabytes(128), 128, 2), - CREATE_BENCHMARK(vectorized, block_scan, int, megabytes(128), 128, 4), - CREATE_BENCHMARK(vectorized, block_scan, int, megabytes(128), 128, 8), - CREATE_BENCHMARK(vectorized, block_scan, int, megabytes(128), 128, 16), - - CREATE_BENCHMARK(vectorized, block_scan, int, megabytes(128), 256, 1), - CREATE_BENCHMARK(vectorized, block_scan, int, megabytes(128), 256, 2), - CREATE_BENCHMARK(vectorized, block_scan, int, megabytes(128), 256, 4), - CREATE_BENCHMARK(vectorized, block_scan, int, megabytes(128), 256, 8), - CREATE_BENCHMARK(vectorized, block_scan, int, megabytes(128), 256, 16), - - CREATE_BENCHMARK(vectorized, block_scan, int, megabytes(128), 512, 1), - CREATE_BENCHMARK(vectorized, block_scan, int, megabytes(128), 512, 2), - CREATE_BENCHMARK(vectorized, block_scan, int, megabytes(128), 512, 4), - CREATE_BENCHMARK(vectorized, block_scan, int, megabytes(128), 512, 8), - - CREATE_BENCHMARK(vectorized, block_scan, int, megabytes(128), 1024, 1), - CREATE_BENCHMARK(vectorized, block_scan, int, megabytes(128), 1024, 2), - CREATE_BENCHMARK(vectorized, block_scan, int, megabytes(128), 1024, 4), - CREATE_BENCHMARK(vectorized, block_scan, int, megabytes(128), 1024, 8), - - CREATE_BENCHMARK(vectorized, block_scan, float, megabytes(128), 128, 1), - CREATE_BENCHMARK(vectorized, block_scan, float, megabytes(128), 128, 2), - CREATE_BENCHMARK(vectorized, block_scan, float, megabytes(128), 128, 4), - CREATE_BENCHMARK(vectorized, block_scan, float, megabytes(128), 128, 8), - CREATE_BENCHMARK(vectorized, block_scan, float, megabytes(128), 128, 16), - - CREATE_BENCHMARK(vectorized, block_scan, float, megabytes(128), 256, 1), - CREATE_BENCHMARK(vectorized, block_scan, float, megabytes(128), 256, 2), - CREATE_BENCHMARK(vectorized, block_scan, float, megabytes(128), 256, 4), - CREATE_BENCHMARK(vectorized, block_scan, float, megabytes(128), 256, 8), - CREATE_BENCHMARK(vectorized, block_scan, float, megabytes(128), 256, 16), - - CREATE_BENCHMARK(vectorized, block_scan, float, megabytes(128), 512, 1), - CREATE_BENCHMARK(vectorized, block_scan, float, megabytes(128), 512, 2), - CREATE_BENCHMARK(vectorized, block_scan, float, megabytes(128), 512, 4), - CREATE_BENCHMARK(vectorized, block_scan, float, megabytes(128), 512, 8), - - CREATE_BENCHMARK(vectorized, block_scan, float, megabytes(128), 1024, 1), - CREATE_BENCHMARK(vectorized, block_scan, float, megabytes(128), 1024, 2), - CREATE_BENCHMARK(vectorized, block_scan, float, megabytes(128), 1024, 4), - CREATE_BENCHMARK(vectorized, block_scan, float, megabytes(128), 1024, 8), - - CREATE_BENCHMARK(vectorized, block_scan, double, megabytes(128), 128, 1), - CREATE_BENCHMARK(vectorized, block_scan, double, megabytes(128), 128, 2), - CREATE_BENCHMARK(vectorized, block_scan, double, megabytes(128), 128, 4), - CREATE_BENCHMARK(vectorized, block_scan, double, megabytes(128), 128, 8), - CREATE_BENCHMARK(vectorized, block_scan, double, megabytes(128), 128, 16), - - CREATE_BENCHMARK(vectorized, block_scan, double, megabytes(128), 256, 1), - CREATE_BENCHMARK(vectorized, block_scan, double, megabytes(128), 256, 2), - CREATE_BENCHMARK(vectorized, block_scan, double, megabytes(128), 256, 4), - CREATE_BENCHMARK(vectorized, block_scan, double, megabytes(128), 256, 8), - CREATE_BENCHMARK(vectorized, block_scan, double, megabytes(128), 256, 16), - - CREATE_BENCHMARK(vectorized, block_scan, double, megabytes(128), 512, 1), - CREATE_BENCHMARK(vectorized, block_scan, double, megabytes(128), 512, 2), - CREATE_BENCHMARK(vectorized, block_scan, double, megabytes(128), 512, 4), - CREATE_BENCHMARK(vectorized, block_scan, double, megabytes(128), 512, 8), - - CREATE_BENCHMARK(vectorized, block_scan, double, megabytes(128), 1024, 1), - CREATE_BENCHMARK(vectorized, block_scan, double, megabytes(128), 1024, 2), - CREATE_BENCHMARK(vectorized, block_scan, double, megabytes(128), 1024, 4), - - CREATE_BENCHMARK(vectorized, block_scan, uint64_t, megabytes(128), 128, 1), - CREATE_BENCHMARK(vectorized, block_scan, uint64_t, megabytes(128), 128, 2), - CREATE_BENCHMARK(vectorized, block_scan, uint64_t, megabytes(128), 128, 4), - CREATE_BENCHMARK(vectorized, block_scan, uint64_t, megabytes(128), 128, 8), - CREATE_BENCHMARK(vectorized, block_scan, uint64_t, megabytes(128), 128, 16), - - CREATE_BENCHMARK(vectorized, block_scan, uint64_t, megabytes(128), 256, 1), - CREATE_BENCHMARK(vectorized, block_scan, uint64_t, megabytes(128), 256, 2), - CREATE_BENCHMARK(vectorized, block_scan, uint64_t, megabytes(128), 256, 4), - CREATE_BENCHMARK(vectorized, block_scan, uint64_t, megabytes(128), 256, 8), - CREATE_BENCHMARK(vectorized, block_scan, uint64_t, megabytes(128), 256, 16), - - CREATE_BENCHMARK(vectorized, block_scan, uint64_t, megabytes(128), 512, 1), - CREATE_BENCHMARK(vectorized, block_scan, uint64_t, megabytes(128), 512, 2), - CREATE_BENCHMARK(vectorized, block_scan, uint64_t, megabytes(128), 512, 4), - CREATE_BENCHMARK(vectorized, block_scan, uint64_t, megabytes(128), 512, 8), - - CREATE_BENCHMARK(vectorized, block_scan, uint64_t, megabytes(128), 1024, 1), - CREATE_BENCHMARK(vectorized, block_scan, uint64_t, megabytes(128), 1024, 2), - CREATE_BENCHMARK(vectorized, block_scan, uint64_t, megabytes(128), 1024, 4), - - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::int128_t, - megabytes(128), - 128, - 1), - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::int128_t, - megabytes(128), - 128, - 2), - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::int128_t, - megabytes(128), - 128, - 4), - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::int128_t, - megabytes(128), - 128, - 8), - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::int128_t, - megabytes(128), - 128, - 16), - - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::int128_t, - megabytes(128), - 256, - 1), - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::int128_t, - megabytes(128), - 256, - 2), - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::int128_t, - megabytes(128), - 256, - 4), - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::int128_t, - megabytes(128), - 256, - 8), - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::int128_t, - megabytes(128), - 256, - 16), - - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::int128_t, - megabytes(128), - 512, - 1), - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::int128_t, - megabytes(128), - 512, - 2), - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::int128_t, - megabytes(128), - 512, - 4), - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::int128_t, - megabytes(128), - 512, - 8), - - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::int128_t, - megabytes(128), - 1024, - 1), - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::int128_t, - megabytes(128), - 1024, - 2), - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::int128_t, - megabytes(128), - 1024, - 4), - - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::uint128_t, - megabytes(128), - 128, - 1), - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::uint128_t, - megabytes(128), - 128, - 2), - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::uint128_t, - megabytes(128), - 128, - 4), - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::uint128_t, - megabytes(128), - 128, - 8), - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::uint128_t, - megabytes(128), - 128, - 16), - - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::uint128_t, - megabytes(128), - 256, - 1), - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::uint128_t, - megabytes(128), - 256, - 2), - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::uint128_t, - megabytes(128), - 256, - 4), - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::uint128_t, - megabytes(128), - 256, - 8), - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::uint128_t, - megabytes(128), - 256, - 16), - - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::uint128_t, - megabytes(128), - 512, - 1), - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::uint128_t, - megabytes(128), - 512, - 2), - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::uint128_t, - megabytes(128), - 512, - 4), - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::uint128_t, - megabytes(128), - 512, - 8), - - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::uint128_t, - megabytes(128), - 1024, - 1), - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::uint128_t, - megabytes(128), - 1024, - 2), - CREATE_BENCHMARK(vectorized, - block_scan, - rocprim::uint128_t, - megabytes(128), - 1024, - 4), - - // custom_op - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - int, - megabytes(128), - 128, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - int, - megabytes(128), - 128, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - int, - megabytes(128), - 128, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - int, - megabytes(128), - 128, - 8), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - int, - megabytes(128), - 128, - 16), - - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - int, - megabytes(128), - 256, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - int, - megabytes(128), - 256, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - int, - megabytes(128), - 256, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - int, - megabytes(128), - 256, - 8), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - int, - megabytes(128), - 256, - 16), - - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - int, - megabytes(128), - 512, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - int, - megabytes(128), - 512, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - int, - megabytes(128), - 512, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - int, - megabytes(128), - 512, - 8), - - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - int, - megabytes(128), - 1024, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - int, - megabytes(128), - 1024, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - int, - megabytes(128), - 1024, - 4), - - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - float, - megabytes(128), - 128, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - float, - megabytes(128), - 128, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - float, - megabytes(128), - 128, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - float, - megabytes(128), - 128, - 8), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - float, - megabytes(128), - 128, - 16), - - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - float, - megabytes(128), - 256, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - float, - megabytes(128), - 256, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - float, - megabytes(128), - 256, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - float, - megabytes(128), - 256, - 8), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - float, - megabytes(128), - 256, - 16), - - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - float, - megabytes(128), - 512, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - float, - megabytes(128), - 512, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - float, - megabytes(128), - 512, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - float, - megabytes(128), - 512, - 8), - - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - float, - megabytes(128), - 1024, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - float, - megabytes(128), - 1024, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - float, - megabytes(128), - 1024, - 4), - - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - double, - megabytes(128), - 128, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - double, - megabytes(128), - 128, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - double, - megabytes(128), - 128, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - double, - megabytes(128), - 128, - 8), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - double, - megabytes(128), - 128, - 16), - - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - double, - megabytes(128), - 256, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - double, - megabytes(128), - 256, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - double, - megabytes(128), - 256, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - double, - megabytes(128), - 256, - 8), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - double, - megabytes(128), - 256, - 16), - - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - double, - megabytes(128), - 512, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - double, - megabytes(128), - 512, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - double, - megabytes(128), - 512, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - double, - megabytes(128), - 512, - 8), - - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - double, - megabytes(128), - 1024, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - double, - megabytes(128), - 1024, - 2), - - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - uint64_t, - megabytes(128), - 128, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - uint64_t, - megabytes(128), - 128, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - uint64_t, - megabytes(128), - 128, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - uint64_t, - megabytes(128), - 128, - 8), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - uint64_t, - megabytes(128), - 128, - 16), - - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - uint64_t, - megabytes(128), - 256, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - uint64_t, - megabytes(128), - 256, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - uint64_t, - megabytes(128), - 256, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - uint64_t, - megabytes(128), - 256, - 8), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - uint64_t, - megabytes(128), - 256, - 16), - - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - uint64_t, - megabytes(128), - 512, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - uint64_t, - megabytes(128), - 512, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - uint64_t, - megabytes(128), - 512, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - uint64_t, - megabytes(128), - 512, - 8), - - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - uint64_t, - megabytes(128), - 1024, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - uint64_t, - megabytes(128), - 1024, - 2), - - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - rocprim::int128_t, - megabytes(128), - 128, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - rocprim::int128_t, - megabytes(128), - 128, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - rocprim::int128_t, - megabytes(128), - 128, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - rocprim::int128_t, - megabytes(128), - 128, - 8), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - rocprim::int128_t, - megabytes(128), - 128, - 16), - - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - rocprim::int128_t, - megabytes(128), - 256, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - rocprim::int128_t, - megabytes(128), - 256, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - rocprim::int128_t, - megabytes(128), - 256, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - rocprim::int128_t, - megabytes(128), - 256, - 8), - - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - rocprim::int128_t, - megabytes(128), - 512, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - rocprim::int128_t, - megabytes(128), - 512, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - rocprim::int128_t, - megabytes(128), - 512, - 4), - - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - rocprim::int128_t, - megabytes(128), - 1024, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - rocprim::int128_t, - megabytes(128), - 1024, - 2), - - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - rocprim::uint128_t, - megabytes(128), - 128, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - rocprim::uint128_t, - megabytes(128), - 128, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - rocprim::uint128_t, - megabytes(128), - 128, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - rocprim::uint128_t, - megabytes(128), - 128, - 8), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - rocprim::uint128_t, - megabytes(128), - 128, - 16), - - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - rocprim::uint128_t, - megabytes(128), - 256, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - rocprim::uint128_t, - megabytes(128), - 256, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - rocprim::uint128_t, - megabytes(128), - 256, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - rocprim::uint128_t, - megabytes(128), - 256, - 8), - - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - rocprim::uint128_t, - megabytes(128), - 512, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - rocprim::uint128_t, - megabytes(128), - 512, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - rocprim::uint128_t, - megabytes(128), - 512, - 4), - - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - rocprim::uint128_t, - megabytes(128), - 1024, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - custom_operation, - rocprim::uint128_t, - megabytes(128), - 1024, - 2), - - // block_primitives_transpose - atomics no collision - CREATE_BENCHMARK(block_primitives_transpose, - atomics_no_collision, - int, - megabytes(128), - 128, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_no_collision, - int, - megabytes(128), - 128, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_no_collision, - int, - megabytes(128), - 128, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_no_collision, - int, - megabytes(128), - 128, - 8), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_no_collision, - int, - megabytes(128), - 128, - 16), - - CREATE_BENCHMARK(block_primitives_transpose, - atomics_no_collision, - int, - megabytes(128), - 256, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_no_collision, - int, - megabytes(128), - 256, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_no_collision, - int, - megabytes(128), - 256, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_no_collision, - int, - megabytes(128), - 256, - 8), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_no_collision, - int, - megabytes(128), - 256, - 16), - - CREATE_BENCHMARK(block_primitives_transpose, - atomics_no_collision, - int, - megabytes(128), - 512, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_no_collision, - int, - megabytes(128), - 512, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_no_collision, - int, - megabytes(128), - 512, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_no_collision, - int, - megabytes(128), - 512, - 8), - - CREATE_BENCHMARK(block_primitives_transpose, - atomics_no_collision, - int, - megabytes(128), - 1024, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_no_collision, - int, - megabytes(128), - 1024, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_no_collision, - int, - megabytes(128), - 1024, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_no_collision, - int, - megabytes(128), - 1024, - 8), - - // block_primitives_transpose - atomics inter block collision - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_block_collision, - int, - megabytes(128), - 128, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_block_collision, - int, - megabytes(128), - 128, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_block_collision, - int, - megabytes(128), - 128, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_block_collision, - int, - megabytes(128), - 128, - 8), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_block_collision, - int, - megabytes(128), - 128, - 16), - - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_block_collision, - int, - megabytes(128), - 256, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_block_collision, - int, - megabytes(128), - 256, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_block_collision, - int, - megabytes(128), - 256, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_block_collision, - int, - megabytes(128), - 256, - 8), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_block_collision, - int, - megabytes(128), - 256, - 16), - - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_block_collision, - int, - megabytes(128), - 512, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_block_collision, - int, - megabytes(128), - 512, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_block_collision, - int, - megabytes(128), - 512, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_block_collision, - int, - megabytes(128), - 512, - 8), - - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_block_collision, - int, - megabytes(128), - 1024, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_block_collision, - int, - megabytes(128), - 1024, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_block_collision, - int, - megabytes(128), - 1024, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_block_collision, - int, - megabytes(128), - 1024, - 8), - - // block_primitives_transpose - atomics inter warp collision - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_warp_collision, - int, - megabytes(128), - 128, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_warp_collision, - int, - megabytes(128), - 128, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_warp_collision, - int, - megabytes(128), - 128, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_warp_collision, - int, - megabytes(128), - 128, - 8), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_warp_collision, - int, - megabytes(128), - 128, - 16), - - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_warp_collision, - int, - megabytes(128), - 256, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_warp_collision, - int, - megabytes(128), - 256, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_warp_collision, - int, - megabytes(128), - 256, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_warp_collision, - int, - megabytes(128), - 256, - 8), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_warp_collision, - int, - megabytes(128), - 256, - 16), - - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_warp_collision, - int, - megabytes(128), - 512, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_warp_collision, - int, - megabytes(128), - 512, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_warp_collision, - int, - megabytes(128), - 512, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_warp_collision, - int, - megabytes(128), - 512, - 8), - - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_warp_collision, - int, - megabytes(128), - 1024, - 1), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_warp_collision, - int, - megabytes(128), - 1024, - 2), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_warp_collision, - int, - megabytes(128), - 1024, - 4), - CREATE_BENCHMARK(block_primitives_transpose, - atomics_inter_warp_collision, - int, - megabytes(128), - 1024, - 8) - - }; - - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } - - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - - return 0; + benchmark_utils::executor executor(argc, argv, 128 * benchmark_utils::MiB, 10, 10); + + // simple memory copy not running kernel + CREATE_BENCHMARK_MEMCPY(int) + + // simple memory copy + CREATE_BENCHMARK(block_primitives_transpose, no_operation, int, 128, 1) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, int, 128, 2) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, int, 128, 4) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, int, 128, 8) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, int, 128, 16) + + CREATE_BENCHMARK(block_primitives_transpose, no_operation, int, 256, 1) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, int, 256, 2) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, int, 256, 4) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, int, 256, 8) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, int, 256, 16) + + CREATE_BENCHMARK(block_primitives_transpose, no_operation, int, 512, 1) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, int, 512, 2) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, int, 512, 4) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, int, 512, 8) + + CREATE_BENCHMARK(block_primitives_transpose, no_operation, int, 1024, 1) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, int, 1024, 2) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, int, 1024, 4) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, int, 1024, 8) + + CREATE_BENCHMARK(block_primitives_transpose, no_operation, uint64_t, 128, 1) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, uint64_t, 128, 2) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, uint64_t, 128, 4) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, uint64_t, 128, 8) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, uint64_t, 128, 16) + + CREATE_BENCHMARK(block_primitives_transpose, no_operation, uint64_t, 256, 1) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, uint64_t, 256, 2) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, uint64_t, 256, 4) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, uint64_t, 256, 8) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, uint64_t, 256, 16) + + CREATE_BENCHMARK(block_primitives_transpose, no_operation, uint64_t, 512, 1) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, uint64_t, 512, 2) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, uint64_t, 512, 4) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, uint64_t, 512, 8) + + CREATE_BENCHMARK(block_primitives_transpose, no_operation, uint64_t, 1024, 1) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, uint64_t, 1024, 2) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, uint64_t, 1024, 4) + + CREATE_BENCHMARK(block_primitives_transpose, no_operation, rocprim::int128_t, 128, 1) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, rocprim::int128_t, 128, 2) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, rocprim::int128_t, 128, 4) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, rocprim::int128_t, 128, 8) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, rocprim::int128_t, 128, 16) + + CREATE_BENCHMARK(block_primitives_transpose, no_operation, rocprim::int128_t, 256, 1) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, rocprim::int128_t, 256, 2) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, rocprim::int128_t, 256, 4) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, rocprim::int128_t, 256, 8) + + CREATE_BENCHMARK(block_primitives_transpose, no_operation, rocprim::int128_t, 512, 1) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, rocprim::int128_t, 512, 2) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, rocprim::int128_t, 512, 4) + + CREATE_BENCHMARK(block_primitives_transpose, no_operation, rocprim::int128_t, 1024, 1) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, rocprim::int128_t, 1024, 2) + + CREATE_BENCHMARK(block_primitives_transpose, no_operation, rocprim::uint128_t, 128, 1) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, rocprim::uint128_t, 128, 2) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, rocprim::uint128_t, 128, 4) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, rocprim::uint128_t, 128, 8) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, rocprim::uint128_t, 128, 16) + + CREATE_BENCHMARK(block_primitives_transpose, no_operation, rocprim::uint128_t, 256, 1) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, rocprim::uint128_t, 256, 2) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, rocprim::uint128_t, 256, 4) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, rocprim::uint128_t, 256, 8) + + CREATE_BENCHMARK(block_primitives_transpose, no_operation, rocprim::uint128_t, 512, 1) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, rocprim::uint128_t, 512, 2) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, rocprim::uint128_t, 512, 4) + + CREATE_BENCHMARK(block_primitives_transpose, no_operation, rocprim::uint128_t, 1024, 1) + CREATE_BENCHMARK(block_primitives_transpose, no_operation, rocprim::uint128_t, 1024, 2) + + // simple memory copy using vector type + CREATE_BENCHMARK(vectorized, no_operation, int, 128, 1) + CREATE_BENCHMARK(vectorized, no_operation, int, 128, 2) + CREATE_BENCHMARK(vectorized, no_operation, int, 128, 4) + CREATE_BENCHMARK(vectorized, no_operation, int, 128, 8) + CREATE_BENCHMARK(vectorized, no_operation, int, 128, 16) + + CREATE_BENCHMARK(vectorized, no_operation, int, 256, 1) + CREATE_BENCHMARK(vectorized, no_operation, int, 256, 2) + CREATE_BENCHMARK(vectorized, no_operation, int, 256, 4) + CREATE_BENCHMARK(vectorized, no_operation, int, 256, 8) + CREATE_BENCHMARK(vectorized, no_operation, int, 256, 16) + + CREATE_BENCHMARK(vectorized, no_operation, int, 512, 1) + CREATE_BENCHMARK(vectorized, no_operation, int, 512, 2) + CREATE_BENCHMARK(vectorized, no_operation, int, 512, 4) + CREATE_BENCHMARK(vectorized, no_operation, int, 512, 8) + + CREATE_BENCHMARK(vectorized, no_operation, int, 1024, 1) + CREATE_BENCHMARK(vectorized, no_operation, int, 1024, 2) + CREATE_BENCHMARK(vectorized, no_operation, int, 1024, 4) + CREATE_BENCHMARK(vectorized, no_operation, int, 1024, 8) + + CREATE_BENCHMARK(vectorized, no_operation, uint64_t, 128, 1) + CREATE_BENCHMARK(vectorized, no_operation, uint64_t, 128, 2) + CREATE_BENCHMARK(vectorized, no_operation, uint64_t, 128, 4) + CREATE_BENCHMARK(vectorized, no_operation, uint64_t, 128, 8) + CREATE_BENCHMARK(vectorized, no_operation, uint64_t, 128, 16) + + CREATE_BENCHMARK(vectorized, no_operation, uint64_t, 256, 1) + CREATE_BENCHMARK(vectorized, no_operation, uint64_t, 256, 2) + CREATE_BENCHMARK(vectorized, no_operation, uint64_t, 256, 4) + CREATE_BENCHMARK(vectorized, no_operation, uint64_t, 256, 8) + CREATE_BENCHMARK(vectorized, no_operation, uint64_t, 256, 16) + + CREATE_BENCHMARK(vectorized, no_operation, uint64_t, 512, 1) + CREATE_BENCHMARK(vectorized, no_operation, uint64_t, 512, 2) + CREATE_BENCHMARK(vectorized, no_operation, uint64_t, 512, 4) + CREATE_BENCHMARK(vectorized, no_operation, uint64_t, 512, 8) + + CREATE_BENCHMARK(vectorized, no_operation, uint64_t, 1024, 1) + CREATE_BENCHMARK(vectorized, no_operation, uint64_t, 1024, 2) + CREATE_BENCHMARK(vectorized, no_operation, uint64_t, 1024, 4) + CREATE_BENCHMARK(vectorized, no_operation, uint64_t, 1024, 8) + + CREATE_BENCHMARK(vectorized, no_operation, rocprim::int128_t, 128, 1) + CREATE_BENCHMARK(vectorized, no_operation, rocprim::int128_t, 128, 2) + CREATE_BENCHMARK(vectorized, no_operation, rocprim::int128_t, 128, 4) + CREATE_BENCHMARK(vectorized, no_operation, rocprim::int128_t, 128, 8) + CREATE_BENCHMARK(vectorized, no_operation, rocprim::int128_t, 128, 16) + + CREATE_BENCHMARK(vectorized, no_operation, rocprim::int128_t, 256, 1) + CREATE_BENCHMARK(vectorized, no_operation, rocprim::int128_t, 256, 2) + CREATE_BENCHMARK(vectorized, no_operation, rocprim::int128_t, 256, 4) + CREATE_BENCHMARK(vectorized, no_operation, rocprim::int128_t, 256, 8) + CREATE_BENCHMARK(vectorized, no_operation, rocprim::int128_t, 256, 16) + + CREATE_BENCHMARK(vectorized, no_operation, rocprim::int128_t, 512, 1) + CREATE_BENCHMARK(vectorized, no_operation, rocprim::int128_t, 512, 2) + CREATE_BENCHMARK(vectorized, no_operation, rocprim::int128_t, 512, 4) + CREATE_BENCHMARK(vectorized, no_operation, rocprim::int128_t, 512, 8) + + CREATE_BENCHMARK(vectorized, no_operation, rocprim::int128_t, 1024, 1) + CREATE_BENCHMARK(vectorized, no_operation, rocprim::int128_t, 1024, 2) + CREATE_BENCHMARK(vectorized, no_operation, rocprim::int128_t, 1024, 4) + CREATE_BENCHMARK(vectorized, no_operation, rocprim::int128_t, 1024, 8) + + CREATE_BENCHMARK(vectorized, no_operation, rocprim::uint128_t, 128, 1) + CREATE_BENCHMARK(vectorized, no_operation, rocprim::uint128_t, 128, 2) + CREATE_BENCHMARK(vectorized, no_operation, rocprim::uint128_t, 128, 4) + CREATE_BENCHMARK(vectorized, no_operation, rocprim::uint128_t, 128, 8) + CREATE_BENCHMARK(vectorized, no_operation, rocprim::uint128_t, 128, 16) + + CREATE_BENCHMARK(vectorized, no_operation, rocprim::uint128_t, 256, 1) + CREATE_BENCHMARK(vectorized, no_operation, rocprim::uint128_t, 256, 2) + CREATE_BENCHMARK(vectorized, no_operation, rocprim::uint128_t, 256, 4) + CREATE_BENCHMARK(vectorized, no_operation, rocprim::uint128_t, 256, 8) + CREATE_BENCHMARK(vectorized, no_operation, rocprim::uint128_t, 256, 16) + + CREATE_BENCHMARK(vectorized, no_operation, rocprim::uint128_t, 512, 1) + CREATE_BENCHMARK(vectorized, no_operation, rocprim::uint128_t, 512, 2) + CREATE_BENCHMARK(vectorized, no_operation, rocprim::uint128_t, 512, 4) + CREATE_BENCHMARK(vectorized, no_operation, rocprim::uint128_t, 512, 8) + + CREATE_BENCHMARK(vectorized, no_operation, rocprim::uint128_t, 1024, 1) + CREATE_BENCHMARK(vectorized, no_operation, rocprim::uint128_t, 1024, 2) + CREATE_BENCHMARK(vectorized, no_operation, rocprim::uint128_t, 1024, 4) + CREATE_BENCHMARK(vectorized, no_operation, rocprim::uint128_t, 1024, 8) + + // simple memory copy using striped + CREATE_BENCHMARK(striped, no_operation, int, 128, 1) + CREATE_BENCHMARK(striped, no_operation, int, 128, 2) + CREATE_BENCHMARK(striped, no_operation, int, 128, 4) + CREATE_BENCHMARK(striped, no_operation, int, 128, 8) + CREATE_BENCHMARK(striped, no_operation, int, 128, 16) + + CREATE_BENCHMARK(striped, no_operation, int, 256, 1) + CREATE_BENCHMARK(striped, no_operation, int, 256, 2) + CREATE_BENCHMARK(striped, no_operation, int, 256, 4) + CREATE_BENCHMARK(striped, no_operation, int, 256, 8) + CREATE_BENCHMARK(striped, no_operation, int, 256, 16) + + CREATE_BENCHMARK(striped, no_operation, int, 512, 1) + CREATE_BENCHMARK(striped, no_operation, int, 512, 2) + CREATE_BENCHMARK(striped, no_operation, int, 512, 4) + CREATE_BENCHMARK(striped, no_operation, int, 512, 8) + + CREATE_BENCHMARK(striped, no_operation, int, 1024, 1) + CREATE_BENCHMARK(striped, no_operation, int, 1024, 2) + CREATE_BENCHMARK(striped, no_operation, int, 1024, 4) + CREATE_BENCHMARK(striped, no_operation, int, 1024, 8) + + CREATE_BENCHMARK(striped, no_operation, uint64_t, 128, 1) + CREATE_BENCHMARK(striped, no_operation, uint64_t, 128, 2) + CREATE_BENCHMARK(striped, no_operation, uint64_t, 128, 4) + CREATE_BENCHMARK(striped, no_operation, uint64_t, 128, 8) + CREATE_BENCHMARK(striped, no_operation, uint64_t, 128, 16) + + CREATE_BENCHMARK(striped, no_operation, uint64_t, 256, 1) + CREATE_BENCHMARK(striped, no_operation, uint64_t, 256, 2) + CREATE_BENCHMARK(striped, no_operation, uint64_t, 256, 4) + CREATE_BENCHMARK(striped, no_operation, uint64_t, 256, 8) + CREATE_BENCHMARK(striped, no_operation, uint64_t, 256, 16) + + CREATE_BENCHMARK(striped, no_operation, uint64_t, 512, 1) + CREATE_BENCHMARK(striped, no_operation, uint64_t, 512, 2) + CREATE_BENCHMARK(striped, no_operation, uint64_t, 512, 4) + CREATE_BENCHMARK(striped, no_operation, uint64_t, 512, 8) + + CREATE_BENCHMARK(striped, no_operation, uint64_t, 1024, 1) + CREATE_BENCHMARK(striped, no_operation, uint64_t, 1024, 2) + CREATE_BENCHMARK(striped, no_operation, uint64_t, 1024, 4) + CREATE_BENCHMARK(striped, no_operation, uint64_t, 1024, 8) + + CREATE_BENCHMARK(striped, no_operation, rocprim::int128_t, 128, 1) + CREATE_BENCHMARK(striped, no_operation, rocprim::int128_t, 128, 2) + CREATE_BENCHMARK(striped, no_operation, rocprim::int128_t, 128, 4) + CREATE_BENCHMARK(striped, no_operation, rocprim::int128_t, 128, 8) + CREATE_BENCHMARK(striped, no_operation, rocprim::int128_t, 128, 16) + + CREATE_BENCHMARK(striped, no_operation, rocprim::int128_t, 256, 1) + CREATE_BENCHMARK(striped, no_operation, rocprim::int128_t, 256, 2) + CREATE_BENCHMARK(striped, no_operation, rocprim::int128_t, 256, 4) + CREATE_BENCHMARK(striped, no_operation, rocprim::int128_t, 256, 8) + CREATE_BENCHMARK(striped, no_operation, rocprim::int128_t, 256, 16) + + CREATE_BENCHMARK(striped, no_operation, rocprim::int128_t, 512, 1) + CREATE_BENCHMARK(striped, no_operation, rocprim::int128_t, 512, 2) + CREATE_BENCHMARK(striped, no_operation, rocprim::int128_t, 512, 4) + CREATE_BENCHMARK(striped, no_operation, rocprim::int128_t, 512, 8) + + CREATE_BENCHMARK(striped, no_operation, rocprim::int128_t, 1024, 1) + CREATE_BENCHMARK(striped, no_operation, rocprim::int128_t, 1024, 2) + CREATE_BENCHMARK(striped, no_operation, rocprim::int128_t, 1024, 4) + CREATE_BENCHMARK(striped, no_operation, rocprim::int128_t, 1024, 8) + + CREATE_BENCHMARK(striped, no_operation, rocprim::uint128_t, 128, 1) + CREATE_BENCHMARK(striped, no_operation, rocprim::uint128_t, 128, 2) + CREATE_BENCHMARK(striped, no_operation, rocprim::uint128_t, 128, 4) + CREATE_BENCHMARK(striped, no_operation, rocprim::uint128_t, 128, 8) + CREATE_BENCHMARK(striped, no_operation, rocprim::uint128_t, 128, 16) + + CREATE_BENCHMARK(striped, no_operation, rocprim::uint128_t, 256, 1) + CREATE_BENCHMARK(striped, no_operation, rocprim::uint128_t, 256, 2) + CREATE_BENCHMARK(striped, no_operation, rocprim::uint128_t, 256, 4) + CREATE_BENCHMARK(striped, no_operation, rocprim::uint128_t, 256, 8) + CREATE_BENCHMARK(striped, no_operation, rocprim::uint128_t, 256, 16) + + CREATE_BENCHMARK(striped, no_operation, rocprim::uint128_t, 512, 1) + CREATE_BENCHMARK(striped, no_operation, rocprim::uint128_t, 512, 2) + CREATE_BENCHMARK(striped, no_operation, rocprim::uint128_t, 512, 4) + CREATE_BENCHMARK(striped, no_operation, rocprim::uint128_t, 512, 8) + + CREATE_BENCHMARK(striped, no_operation, rocprim::uint128_t, 1024, 1) + CREATE_BENCHMARK(striped, no_operation, rocprim::uint128_t, 1024, 2) + CREATE_BENCHMARK(striped, no_operation, rocprim::uint128_t, 1024, 4) + CREATE_BENCHMARK(striped, no_operation, rocprim::uint128_t, 1024, 8) + + // block_scan + CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, 128, 1) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, 128, 2) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, 128, 4) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, 128, 8) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, 128, 16) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, 128, 32) + + CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, 256, 1) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, 256, 2) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, 256, 4) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, 256, 8) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, 256, 16) + + CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, 512, 1) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, 512, 2) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, 512, 4) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, 512, 8) + + CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, 1024, 1) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, 1024, 2) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, 1024, 4) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, int, 1024, 8) + + CREATE_BENCHMARK(block_primitives_transpose, block_scan, float, 256, 1) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, float, 256, 2) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, float, 256, 4) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, float, 256, 8) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, float, 256, 16) + + CREATE_BENCHMARK(block_primitives_transpose, block_scan, float, 512, 1) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, float, 512, 2) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, float, 512, 4) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, float, 512, 8) + + CREATE_BENCHMARK(block_primitives_transpose, block_scan, float, 1024, 1) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, float, 1024, 2) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, float, 1024, 4) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, float, 1024, 8) + + CREATE_BENCHMARK(block_primitives_transpose, block_scan, double, 128, 1) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, double, 128, 2) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, double, 128, 4) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, double, 128, 8) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, double, 128, 16) + + CREATE_BENCHMARK(block_primitives_transpose, block_scan, double, 256, 1) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, double, 256, 2) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, double, 256, 4) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, double, 256, 8) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, double, 256, 16) + + CREATE_BENCHMARK(block_primitives_transpose, block_scan, double, 512, 1) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, double, 512, 2) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, double, 512, 4) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, double, 512, 8) + + CREATE_BENCHMARK(block_primitives_transpose, block_scan, double, 1024, 1) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, double, 1024, 2) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, double, 1024, 4) + + CREATE_BENCHMARK(block_primitives_transpose, block_scan, uint64_t, 128, 1) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, uint64_t, 128, 2) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, uint64_t, 128, 4) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, uint64_t, 128, 8) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, uint64_t, 128, 16) + + CREATE_BENCHMARK(block_primitives_transpose, block_scan, uint64_t, 256, 1) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, uint64_t, 256, 2) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, uint64_t, 256, 4) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, uint64_t, 256, 8) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, uint64_t, 256, 16) + + CREATE_BENCHMARK(block_primitives_transpose, block_scan, uint64_t, 512, 1) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, uint64_t, 512, 2) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, uint64_t, 512, 4) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, uint64_t, 512, 8) + + CREATE_BENCHMARK(block_primitives_transpose, block_scan, uint64_t, 1024, 1) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, uint64_t, 1024, 2) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, uint64_t, 1024, 4) + + CREATE_BENCHMARK(block_primitives_transpose, block_scan, rocprim::int128_t, 128, 1) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, rocprim::int128_t, 128, 2) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, rocprim::int128_t, 128, 4) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, rocprim::int128_t, 128, 8) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, rocprim::int128_t, 128, 16) + + CREATE_BENCHMARK(block_primitives_transpose, block_scan, rocprim::int128_t, 256, 1) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, rocprim::int128_t, 256, 2) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, rocprim::int128_t, 256, 4) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, rocprim::int128_t, 256, 8) + + CREATE_BENCHMARK(block_primitives_transpose, block_scan, rocprim::int128_t, 512, 1) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, rocprim::int128_t, 512, 2) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, rocprim::int128_t, 512, 4) + + CREATE_BENCHMARK(block_primitives_transpose, block_scan, rocprim::int128_t, 1024, 1) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, rocprim::int128_t, 1024, 2) + + CREATE_BENCHMARK(block_primitives_transpose, block_scan, rocprim::uint128_t, 128, 1) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, rocprim::uint128_t, 128, 2) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, rocprim::uint128_t, 128, 4) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, rocprim::uint128_t, 128, 8) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, rocprim::uint128_t, 128, 16) + + CREATE_BENCHMARK(block_primitives_transpose, block_scan, rocprim::uint128_t, 256, 1) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, rocprim::uint128_t, 256, 2) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, rocprim::uint128_t, 256, 4) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, rocprim::uint128_t, 256, 8) + + CREATE_BENCHMARK(block_primitives_transpose, block_scan, rocprim::uint128_t, 512, 1) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, rocprim::uint128_t, 512, 2) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, rocprim::uint128_t, 512, 4) + + CREATE_BENCHMARK(block_primitives_transpose, block_scan, rocprim::uint128_t, 1024, 1) + CREATE_BENCHMARK(block_primitives_transpose, block_scan, rocprim::uint128_t, 1024, 2) + + // vectorized - block_scan + CREATE_BENCHMARK(vectorized, block_scan, int, 128, 1) + CREATE_BENCHMARK(vectorized, block_scan, int, 128, 2) + CREATE_BENCHMARK(vectorized, block_scan, int, 128, 4) + CREATE_BENCHMARK(vectorized, block_scan, int, 128, 8) + CREATE_BENCHMARK(vectorized, block_scan, int, 128, 16) + + CREATE_BENCHMARK(vectorized, block_scan, int, 256, 1) + CREATE_BENCHMARK(vectorized, block_scan, int, 256, 2) + CREATE_BENCHMARK(vectorized, block_scan, int, 256, 4) + CREATE_BENCHMARK(vectorized, block_scan, int, 256, 8) + CREATE_BENCHMARK(vectorized, block_scan, int, 256, 16) + + CREATE_BENCHMARK(vectorized, block_scan, int, 512, 1) + CREATE_BENCHMARK(vectorized, block_scan, int, 512, 2) + CREATE_BENCHMARK(vectorized, block_scan, int, 512, 4) + CREATE_BENCHMARK(vectorized, block_scan, int, 512, 8) + + CREATE_BENCHMARK(vectorized, block_scan, int, 1024, 1) + CREATE_BENCHMARK(vectorized, block_scan, int, 1024, 2) + CREATE_BENCHMARK(vectorized, block_scan, int, 1024, 4) + CREATE_BENCHMARK(vectorized, block_scan, int, 1024, 8) + + CREATE_BENCHMARK(vectorized, block_scan, float, 128, 1) + CREATE_BENCHMARK(vectorized, block_scan, float, 128, 2) + CREATE_BENCHMARK(vectorized, block_scan, float, 128, 4) + CREATE_BENCHMARK(vectorized, block_scan, float, 128, 8) + CREATE_BENCHMARK(vectorized, block_scan, float, 128, 16) + + CREATE_BENCHMARK(vectorized, block_scan, float, 256, 1) + CREATE_BENCHMARK(vectorized, block_scan, float, 256, 2) + CREATE_BENCHMARK(vectorized, block_scan, float, 256, 4) + CREATE_BENCHMARK(vectorized, block_scan, float, 256, 8) + CREATE_BENCHMARK(vectorized, block_scan, float, 256, 16) + + CREATE_BENCHMARK(vectorized, block_scan, float, 512, 1) + CREATE_BENCHMARK(vectorized, block_scan, float, 512, 2) + CREATE_BENCHMARK(vectorized, block_scan, float, 512, 4) + CREATE_BENCHMARK(vectorized, block_scan, float, 512, 8) + + CREATE_BENCHMARK(vectorized, block_scan, float, 1024, 1) + CREATE_BENCHMARK(vectorized, block_scan, float, 1024, 2) + CREATE_BENCHMARK(vectorized, block_scan, float, 1024, 4) + CREATE_BENCHMARK(vectorized, block_scan, float, 1024, 8) + + CREATE_BENCHMARK(vectorized, block_scan, double, 128, 1) + CREATE_BENCHMARK(vectorized, block_scan, double, 128, 2) + CREATE_BENCHMARK(vectorized, block_scan, double, 128, 4) + CREATE_BENCHMARK(vectorized, block_scan, double, 128, 8) + CREATE_BENCHMARK(vectorized, block_scan, double, 128, 16) + + CREATE_BENCHMARK(vectorized, block_scan, double, 256, 1) + CREATE_BENCHMARK(vectorized, block_scan, double, 256, 2) + CREATE_BENCHMARK(vectorized, block_scan, double, 256, 4) + CREATE_BENCHMARK(vectorized, block_scan, double, 256, 8) + CREATE_BENCHMARK(vectorized, block_scan, double, 256, 16) + + CREATE_BENCHMARK(vectorized, block_scan, double, 512, 1) + CREATE_BENCHMARK(vectorized, block_scan, double, 512, 2) + CREATE_BENCHMARK(vectorized, block_scan, double, 512, 4) + CREATE_BENCHMARK(vectorized, block_scan, double, 512, 8) + + CREATE_BENCHMARK(vectorized, block_scan, double, 1024, 1) + CREATE_BENCHMARK(vectorized, block_scan, double, 1024, 2) + CREATE_BENCHMARK(vectorized, block_scan, double, 1024, 4) + + CREATE_BENCHMARK(vectorized, block_scan, uint64_t, 128, 1) + CREATE_BENCHMARK(vectorized, block_scan, uint64_t, 128, 2) + CREATE_BENCHMARK(vectorized, block_scan, uint64_t, 128, 4) + CREATE_BENCHMARK(vectorized, block_scan, uint64_t, 128, 8) + CREATE_BENCHMARK(vectorized, block_scan, uint64_t, 128, 16) + + CREATE_BENCHMARK(vectorized, block_scan, uint64_t, 256, 1) + CREATE_BENCHMARK(vectorized, block_scan, uint64_t, 256, 2) + CREATE_BENCHMARK(vectorized, block_scan, uint64_t, 256, 4) + CREATE_BENCHMARK(vectorized, block_scan, uint64_t, 256, 8) + CREATE_BENCHMARK(vectorized, block_scan, uint64_t, 256, 16) + + CREATE_BENCHMARK(vectorized, block_scan, uint64_t, 512, 1) + CREATE_BENCHMARK(vectorized, block_scan, uint64_t, 512, 2) + CREATE_BENCHMARK(vectorized, block_scan, uint64_t, 512, 4) + CREATE_BENCHMARK(vectorized, block_scan, uint64_t, 512, 8) + + CREATE_BENCHMARK(vectorized, block_scan, uint64_t, 1024, 1) + CREATE_BENCHMARK(vectorized, block_scan, uint64_t, 1024, 2) + CREATE_BENCHMARK(vectorized, block_scan, uint64_t, 1024, 4) + + CREATE_BENCHMARK(vectorized, block_scan, rocprim::int128_t, 128, 1) + CREATE_BENCHMARK(vectorized, block_scan, rocprim::int128_t, 128, 2) + CREATE_BENCHMARK(vectorized, block_scan, rocprim::int128_t, 128, 4) + CREATE_BENCHMARK(vectorized, block_scan, rocprim::int128_t, 128, 8) + CREATE_BENCHMARK(vectorized, block_scan, rocprim::int128_t, 128, 16) + + CREATE_BENCHMARK(vectorized, block_scan, rocprim::int128_t, 256, 1) + CREATE_BENCHMARK(vectorized, block_scan, rocprim::int128_t, 256, 2) + CREATE_BENCHMARK(vectorized, block_scan, rocprim::int128_t, 256, 4) + CREATE_BENCHMARK(vectorized, block_scan, rocprim::int128_t, 256, 8) + CREATE_BENCHMARK(vectorized, block_scan, rocprim::int128_t, 256, 16) + + CREATE_BENCHMARK(vectorized, block_scan, rocprim::int128_t, 512, 1) + CREATE_BENCHMARK(vectorized, block_scan, rocprim::int128_t, 512, 2) + CREATE_BENCHMARK(vectorized, block_scan, rocprim::int128_t, 512, 4) + CREATE_BENCHMARK(vectorized, block_scan, rocprim::int128_t, 512, 8) + + CREATE_BENCHMARK(vectorized, block_scan, rocprim::int128_t, 1024, 1) + CREATE_BENCHMARK(vectorized, block_scan, rocprim::int128_t, 1024, 2) + CREATE_BENCHMARK(vectorized, block_scan, rocprim::int128_t, 1024, 4) + + CREATE_BENCHMARK(vectorized, block_scan, rocprim::uint128_t, 128, 1) + CREATE_BENCHMARK(vectorized, block_scan, rocprim::uint128_t, 128, 2) + CREATE_BENCHMARK(vectorized, block_scan, rocprim::uint128_t, 128, 4) + CREATE_BENCHMARK(vectorized, block_scan, rocprim::uint128_t, 128, 8) + CREATE_BENCHMARK(vectorized, block_scan, rocprim::uint128_t, 128, 16) + + CREATE_BENCHMARK(vectorized, block_scan, rocprim::uint128_t, 256, 1) + CREATE_BENCHMARK(vectorized, block_scan, rocprim::uint128_t, 256, 2) + CREATE_BENCHMARK(vectorized, block_scan, rocprim::uint128_t, 256, 4) + CREATE_BENCHMARK(vectorized, block_scan, rocprim::uint128_t, 256, 8) + CREATE_BENCHMARK(vectorized, block_scan, rocprim::uint128_t, 256, 16) + + CREATE_BENCHMARK(vectorized, block_scan, rocprim::uint128_t, 512, 1) + CREATE_BENCHMARK(vectorized, block_scan, rocprim::uint128_t, 512, 2) + CREATE_BENCHMARK(vectorized, block_scan, rocprim::uint128_t, 512, 4) + CREATE_BENCHMARK(vectorized, block_scan, rocprim::uint128_t, 512, 8) + + CREATE_BENCHMARK(vectorized, block_scan, rocprim::uint128_t, 1024, 1) + CREATE_BENCHMARK(vectorized, block_scan, rocprim::uint128_t, 1024, 2) + CREATE_BENCHMARK(vectorized, block_scan, rocprim::uint128_t, 1024, 4) + + // custom_op + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, int, 128, 1) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, int, 128, 2) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, int, 128, 4) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, int, 128, 8) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, int, 128, 16) + + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, int, 256, 1) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, int, 256, 2) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, int, 256, 4) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, int, 256, 8) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, int, 256, 16) + + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, int, 512, 1) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, int, 512, 2) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, int, 512, 4) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, int, 512, 8) + + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, int, 1024, 1) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, int, 1024, 2) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, int, 1024, 4) + + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, float, 128, 1) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, float, 128, 2) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, float, 128, 4) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, float, 128, 8) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, float, 128, 16) + + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, float, 256, 1) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, float, 256, 2) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, float, 256, 4) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, float, 256, 8) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, float, 256, 16) + + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, float, 512, 1) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, float, 512, 2) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, float, 512, 4) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, float, 512, 8) + + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, float, 1024, 1) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, float, 1024, 2) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, float, 1024, 4) + + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, double, 128, 1) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, double, 128, 2) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, double, 128, 4) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, double, 128, 8) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, double, 128, 16) + + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, double, 256, 1) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, double, 256, 2) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, double, 256, 4) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, double, 256, 8) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, double, 256, 16) + + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, double, 512, 1) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, double, 512, 2) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, double, 512, 4) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, double, 512, 8) + + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, double, 1024, 1) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, double, 1024, 2) + + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, uint64_t, 128, 1) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, uint64_t, 128, 2) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, uint64_t, 128, 4) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, uint64_t, 128, 8) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, uint64_t, 128, 16) + + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, uint64_t, 256, 1) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, uint64_t, 256, 2) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, uint64_t, 256, 4) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, uint64_t, 256, 8) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, uint64_t, 256, 16) + + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, uint64_t, 512, 1) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, uint64_t, 512, 2) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, uint64_t, 512, 4) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, uint64_t, 512, 8) + + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, uint64_t, 1024, 1) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, uint64_t, 1024, 2) + + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, rocprim::int128_t, 128, 1) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, rocprim::int128_t, 128, 2) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, rocprim::int128_t, 128, 4) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, rocprim::int128_t, 128, 8) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, rocprim::int128_t, 128, 16) + + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, rocprim::int128_t, 256, 1) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, rocprim::int128_t, 256, 2) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, rocprim::int128_t, 256, 4) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, rocprim::int128_t, 256, 8) + + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, rocprim::int128_t, 512, 1) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, rocprim::int128_t, 512, 2) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, rocprim::int128_t, 512, 4) + + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, rocprim::int128_t, 1024, 1) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, rocprim::int128_t, 1024, 2) + + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, rocprim::uint128_t, 128, 1) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, rocprim::uint128_t, 128, 2) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, rocprim::uint128_t, 128, 4) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, rocprim::uint128_t, 128, 8) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, rocprim::uint128_t, 128, 16) + + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, rocprim::uint128_t, 256, 1) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, rocprim::uint128_t, 256, 2) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, rocprim::uint128_t, 256, 4) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, rocprim::uint128_t, 256, 8) + + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, rocprim::uint128_t, 512, 1) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, rocprim::uint128_t, 512, 2) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, rocprim::uint128_t, 512, 4) + + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, rocprim::uint128_t, 1024, 1) + CREATE_BENCHMARK(block_primitives_transpose, custom_operation, rocprim::uint128_t, 1024, 2) + + // block_primitives_transpose - atomics no collision + CREATE_BENCHMARK(block_primitives_transpose, atomics_no_collision, int, 128, 1) + CREATE_BENCHMARK(block_primitives_transpose, atomics_no_collision, int, 128, 2) + CREATE_BENCHMARK(block_primitives_transpose, atomics_no_collision, int, 128, 4) + CREATE_BENCHMARK(block_primitives_transpose, atomics_no_collision, int, 128, 8) + CREATE_BENCHMARK(block_primitives_transpose, atomics_no_collision, int, 128, 16) + + CREATE_BENCHMARK(block_primitives_transpose, atomics_no_collision, int, 256, 1) + CREATE_BENCHMARK(block_primitives_transpose, atomics_no_collision, int, 256, 2) + CREATE_BENCHMARK(block_primitives_transpose, atomics_no_collision, int, 256, 4) + CREATE_BENCHMARK(block_primitives_transpose, atomics_no_collision, int, 256, 8) + CREATE_BENCHMARK(block_primitives_transpose, atomics_no_collision, int, 256, 16) + + CREATE_BENCHMARK(block_primitives_transpose, atomics_no_collision, int, 512, 1) + CREATE_BENCHMARK(block_primitives_transpose, atomics_no_collision, int, 512, 2) + CREATE_BENCHMARK(block_primitives_transpose, atomics_no_collision, int, 512, 4) + CREATE_BENCHMARK(block_primitives_transpose, atomics_no_collision, int, 512, 8) + + CREATE_BENCHMARK(block_primitives_transpose, atomics_no_collision, int, 1024, 1) + CREATE_BENCHMARK(block_primitives_transpose, atomics_no_collision, int, 1024, 2) + CREATE_BENCHMARK(block_primitives_transpose, atomics_no_collision, int, 1024, 4) + CREATE_BENCHMARK(block_primitives_transpose, atomics_no_collision, int, 1024, 8) + + // block_primitives_transpose - atomics inter block collision + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_block_collision, int, 128, 1) + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_block_collision, int, 128, 2) + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_block_collision, int, 128, 4) + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_block_collision, int, 128, 8) + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_block_collision, int, 128, 16) + + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_block_collision, int, 256, 1) + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_block_collision, int, 256, 2) + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_block_collision, int, 256, 4) + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_block_collision, int, 256, 8) + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_block_collision, int, 256, 16) + + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_block_collision, int, 512, 1) + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_block_collision, int, 512, 2) + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_block_collision, int, 512, 4) + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_block_collision, int, 512, 8) + + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_block_collision, int, 1024, 1) + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_block_collision, int, 1024, 2) + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_block_collision, int, 1024, 4) + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_block_collision, int, 1024, 8) + + // block_primitives_transpose - atomics inter warp collision + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_warp_collision, int, 128, 1) + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_warp_collision, int, 128, 2) + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_warp_collision, int, 128, 4) + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_warp_collision, int, 128, 8) + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_warp_collision, int, 128, 16) + + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_warp_collision, int, 256, 1) + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_warp_collision, int, 256, 2) + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_warp_collision, int, 256, 4) + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_warp_collision, int, 256, 8) + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_warp_collision, int, 256, 16) + + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_warp_collision, int, 512, 1) + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_warp_collision, int, 512, 2) + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_warp_collision, int, 512, 4) + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_warp_collision, int, 512, 8) + + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_warp_collision, int, 1024, 1) + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_warp_collision, int, 1024, 2) + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_warp_collision, int, 1024, 4) + CREATE_BENCHMARK(block_primitives_transpose, atomics_inter_warp_collision, int, 1024, 8) + + executor.run(); } diff --git a/benchmark/benchmark_device_merge.cpp b/benchmark/benchmark_device_merge.cpp index 3c635e4e7..4fdceba53 100644 --- a/benchmark/benchmark_device_merge.cpp +++ b/benchmark/benchmark_device_merge.cpp @@ -22,16 +22,11 @@ #include "benchmark_device_merge.parallel.hpp" #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" #ifndef BENCHMARK_CONFIG_TUNING #include "../common/utils_custom_type.hpp" #endif -// Google Benchmark -#include - // HIP API #include @@ -47,84 +42,17 @@ #include #endif -#ifndef DEFAULT_BYTES -const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; -#endif - -#define CREATE_BENCHMARK(...) \ - { \ - const device_merge_benchmark<__VA_ARGS__> instance; \ - REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ - } - -#define CREATE_MERGE_KEYS_BENCHMARK(Key) \ - benchmark::RegisterBenchmark( \ - bench_naming::format_name("{lvl:device,algo:merge,key_type:" #Key ",cfg:default_config}") \ - .c_str(), \ - [=](benchmark::State& state) \ - { run_merge_keys_benchmark(state, bytes, seed, stream); }) - -#define CREATE_MERGE_PAIRS_BENCHMARK(Key, Value) \ - benchmark::RegisterBenchmark( \ - bench_naming::format_name("{lvl:device,algo:merge,key_type:" #Key ",value_type:" #Value \ - ",cfg:default_config}") \ - .c_str(), \ - [=](benchmark::State& state) \ - { run_merge_pairs_benchmark(state, bytes, seed, stream); }) +#define CREATE_BENCHMARK(...) executor.queue_instance(device_merge_benchmark<__VA_ARGS__>()); int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); -#ifdef BENCHMARK_CONFIG_TUNING - // optionally run an evenly split subset of benchmarks, when making multiple program invocations - parser.set_optional("parallel_instance", - "parallel_instance", - 0, - "parallel instance index"); - parser.set_optional("parallel_instances", - "parallel_instances", - 1, - "total parallel instances"); -#endif - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); + benchmark_utils::executor executor(argc, argv, 128 * benchmark_utils::MiB, 10, 5); - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks = {}; -#ifdef BENCHMARK_CONFIG_TUNING - const int parallel_instance = parser.get("parallel_instance"); - const int parallel_instances = parser.get("parallel_instances"); - config_autotune_register::register_benchmark_subset(benchmarks, - parallel_instance, - parallel_instances, - bytes, - seed, - stream); -#else // BENCHMARK_CONFIG_TUNING - using custom_int2 = common::custom_type; - using custom_double2 = common::custom_type; +#ifndef BENCHMARK_CONFIG_TUNING + using custom_int2 = common::custom_type; + using custom_double2 = common::custom_type; + using huge_float2_1024 = common::custom_huge_type<1024, float, float>; + using huge_float2_2048 = common::custom_huge_type<2048, float, float>; CREATE_BENCHMARK(int) CREATE_BENCHMARK(long long) @@ -136,6 +64,8 @@ int main(int argc, char* argv[]) CREATE_BENCHMARK(custom_double2) CREATE_BENCHMARK(rocprim::int128_t) CREATE_BENCHMARK(rocprim::uint128_t) + CREATE_BENCHMARK(huge_float2_1024) + CREATE_BENCHMARK(huge_float2_2048) CREATE_BENCHMARK(int, int) CREATE_BENCHMARK(long long, long long) @@ -147,26 +77,9 @@ int main(int argc, char* argv[]) CREATE_BENCHMARK(custom_double2, custom_double2) CREATE_BENCHMARK(rocprim::int128_t, rocprim::int128_t) CREATE_BENCHMARK(rocprim::uint128_t, rocprim::uint128_t) + CREATE_BENCHMARK(huge_float2_1024, huge_float2_1024) + CREATE_BENCHMARK(huge_float2_2048, huge_float2_2048) +#endif -#endif // BENCHMARK_CONFIG_TUNING - - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } - - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_merge.parallel.cpp.in b/benchmark/benchmark_device_merge.parallel.cpp.in index 3147c3ac0..7caa74930 100644 --- a/benchmark/benchmark_device_merge.parallel.cpp.in +++ b/benchmark/benchmark_device_merge.parallel.cpp.in @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -29,6 +29,6 @@ namespace { -auto benchmarks = config_autotune_register::create_bulk( +auto unused = benchmark_utils::executor::queue_autotune( device_merge_benchmark_generator<@KeyType@, @ValueType@, @BlockSize@>::create); } diff --git a/benchmark/benchmark_device_merge.parallel.hpp b/benchmark/benchmark_device_merge.parallel.hpp index ec6df61fb..0e5ac173e 100644 --- a/benchmark/benchmark_device_merge.parallel.hpp +++ b/benchmark/benchmark_device_merge.parallel.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -25,11 +25,12 @@ #include "benchmark_utils.hpp" +#include "../common/utils_device_ptr.hpp" + // Google Benchmark #include // HIP API -#include #include // rocPRIM HIP API @@ -66,7 +67,7 @@ inline std::string config_name() template -struct device_merge_benchmark : public config_autotune_interface +struct device_merge_benchmark : public benchmark_utils::autotune_interface { std::string name() const override { @@ -76,17 +77,15 @@ struct device_merge_benchmark : public config_autotune_interface + ",cfg:" + config_name() + "}"); } - static constexpr unsigned int batch_size = 10; - static constexpr unsigned int warmup_size = 5; - // keys benchmark template - auto do_run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const -> + auto do_run(benchmark_utils::state&& state) const -> typename std::enable_if::value, void>::type { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + using key_type = KeyType; using compare_op_type = typename std::conditional::value, @@ -114,106 +113,52 @@ struct device_merge_benchmark : public config_autotune_interface std::sort(keys_input1.begin(), keys_input1.end(), compare_op); std::sort(keys_input2.begin(), keys_input2.end(), compare_op); - key_type* d_keys_input1; - key_type* d_keys_input2; - key_type* d_keys_output; - HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys_input1), size1 * sizeof(key_type))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys_input2), size2 * sizeof(key_type))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys_output), size * sizeof(key_type))); - HIP_CHECK(hipMemcpy(d_keys_input1, - keys_input1.data(), - size1 * sizeof(key_type), - hipMemcpyHostToDevice)); - HIP_CHECK(hipMemcpy(d_keys_input2, - keys_input2.data(), - size2 * sizeof(key_type), - hipMemcpyHostToDevice)); - - void* d_temporary_storage = nullptr; + common::device_ptr d_keys_input1(keys_input1); + common::device_ptr d_keys_input2(keys_input2); + common::device_ptr d_keys_output(size); + + common::device_ptr d_temporary_storage; size_t temporary_storage_bytes = 0; - HIP_CHECK(rocprim::merge(d_temporary_storage, + HIP_CHECK(rocprim::merge(d_temporary_storage.get(), temporary_storage_bytes, - d_keys_input1, - d_keys_input2, - d_keys_output, + d_keys_input1.get(), + d_keys_input2.get(), + d_keys_output.get(), size1, size2, compare_op, stream, false)); - HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); - - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - HIP_CHECK(rocprim::merge(d_temporary_storage, - temporary_storage_bytes, - d_keys_input1, - d_keys_input2, - d_keys_output, - size1, - size2, - compare_op, - stream, - false)); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); + d_temporary_storage.resize(temporary_storage_bytes); - for(size_t i = 0; i < batch_size; ++i) + state.run( + [&] { - HIP_CHECK(rocprim::merge(d_temporary_storage, + HIP_CHECK(rocprim::merge(d_temporary_storage.get(), temporary_storage_bytes, - d_keys_input1, - d_keys_input2, - d_keys_output, + d_keys_input1.get(), + d_keys_input2.get(), + d_keys_output.get(), size1, size2, compare_op, stream, false)); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } + }); - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(key_type)); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_temporary_storage)); - HIP_CHECK(hipFree(d_keys_input1)); - HIP_CHECK(hipFree(d_keys_input2)); - HIP_CHECK(hipFree(d_keys_output)); + state.set_throughput(size, sizeof(key_type)); } // pairs benchmark template - auto do_run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const -> + auto do_run(benchmark_utils::state&& state) const -> typename std::enable_if::value, void>::type { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + using key_type = KeyType; using value_type = ValueType; using compare_op_type = @@ -245,126 +190,56 @@ struct device_merge_benchmark : public config_autotune_interface std::iota(values_input1.begin(), values_input1.end(), 0); std::iota(values_input2.begin(), values_input2.end(), size1); - key_type* d_keys_input1; - key_type* d_keys_input2; - key_type* d_keys_output; - value_type* d_values_input1; - value_type* d_values_input2; - value_type* d_values_output; - HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys_input1), size1 * sizeof(key_type))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys_input2), size2 * sizeof(key_type))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys_output), size * sizeof(key_type))); - HIP_CHECK( - hipMalloc(reinterpret_cast(&d_values_input1), size1 * sizeof(value_type))); - HIP_CHECK( - hipMalloc(reinterpret_cast(&d_values_input2), size2 * sizeof(value_type))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_values_output), size * sizeof(value_type))); - HIP_CHECK(hipMemcpy(d_keys_input1, - keys_input1.data(), - size1 * sizeof(key_type), - hipMemcpyHostToDevice)); - HIP_CHECK(hipMemcpy(d_keys_input2, - keys_input2.data(), - size2 * sizeof(key_type), - hipMemcpyHostToDevice)); - - void* d_temporary_storage = nullptr; + common::device_ptr d_keys_input1(keys_input1); + common::device_ptr d_keys_input2(keys_input2); + common::device_ptr d_keys_output(size); + common::device_ptr d_values_input1(size1); + common::device_ptr d_values_input2(size2); + common::device_ptr d_values_output(size); + + common::device_ptr d_temporary_storage; size_t temporary_storage_bytes = 0; - HIP_CHECK(rocprim::merge(d_temporary_storage, + HIP_CHECK(rocprim::merge(d_temporary_storage.get(), temporary_storage_bytes, - d_keys_input1, - d_keys_input2, - d_keys_output, - d_values_input1, - d_values_input2, - d_values_output, + d_keys_input1.get(), + d_keys_input2.get(), + d_keys_output.get(), + d_values_input1.get(), + d_values_input2.get(), + d_values_output.get(), size1, size2, compare_op, stream, false)); - HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); - HIP_CHECK(hipDeviceSynchronize()); - - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - HIP_CHECK(rocprim::merge(d_temporary_storage, - temporary_storage_bytes, - d_keys_input1, - d_keys_input2, - d_keys_output, - d_values_input1, - d_values_input2, - d_values_output, - size1, - size2, - compare_op, - stream, - false)); - } + d_temporary_storage.resize(temporary_storage_bytes); HIP_CHECK(hipDeviceSynchronize()); - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) + state.run( + [&] { - HIP_CHECK(rocprim::merge(d_temporary_storage, + HIP_CHECK(rocprim::merge(d_temporary_storage.get(), temporary_storage_bytes, - d_keys_input1, - d_keys_input2, - d_keys_output, - d_values_input1, - d_values_input2, - d_values_output, + d_keys_input1.get(), + d_keys_input2.get(), + d_keys_output.get(), + d_values_input1.get(), + d_values_input2.get(), + d_values_output.get(), size1, size2, compare_op, stream, false)); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } + }); - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size - * (sizeof(key_type) + sizeof(value_type))); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_temporary_storage)); - HIP_CHECK(hipFree(d_keys_input1)); - HIP_CHECK(hipFree(d_keys_input2)); - HIP_CHECK(hipFree(d_keys_output)); - HIP_CHECK(hipFree(d_values_input1)); - HIP_CHECK(hipFree(d_values_input2)); - HIP_CHECK(hipFree(d_values_output)); + state.set_throughput(size, sizeof(key_type) + sizeof(value_type)); } - void run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { - do_run(state, bytes, seed, stream); + do_run(std::forward(state)); } }; @@ -380,7 +255,7 @@ struct device_merge_benchmark_generator using generated_config = rocprim::merge_config; using benchmark_struct = device_merge_benchmark; - void operator()(std::vector>& storage) + void operator()(std::vector>& storage) { storage.emplace_back(std::make_unique()); } @@ -392,13 +267,13 @@ struct device_merge_benchmark_generator typename rocprim::detail::default_merge_config_base::type; using benchmark_struct = device_merge_benchmark; - void operator()(std::vector>& storage) + void operator()(std::vector>& storage) { storage.emplace_back(std::make_unique()); } }; - static void create(std::vector>& storage) + static void create(std::vector>& storage) { static constexpr unsigned int min_items_per_thread_exponent = 0u; diff --git a/benchmark/benchmark_device_merge_inplace.cpp b/benchmark/benchmark_device_merge_inplace.cpp index 50ec39874..625c7f93f 100644 --- a/benchmark/benchmark_device_merge_inplace.cpp +++ b/benchmark/benchmark_device_merge_inplace.cpp @@ -21,18 +21,16 @@ // SOFTWARE. #include "benchmark_utils.hpp" -#include "cmdparser.hpp" #include "../common/utils_data_generation.hpp" - -#include +#include "../common/utils_device_ptr.hpp" #include #include #include #include -#include +#include #include #include @@ -43,10 +41,6 @@ #include #include -#ifndef DEFAULT_BYTES -const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; -#endif - template struct random_monotonic_iterator { @@ -107,7 +101,7 @@ struct inplace_runner size_t right_size; hipStream_t stream; - void* d_temporary_storage = nullptr; + common::device_ptr d_temporary_storage; size_t temporary_storage_bytes = 0; compare_op_type compare_op{}; @@ -116,23 +110,21 @@ struct inplace_runner : d_data(data), left_size(left_size), right_size(right_size), stream(stream) {} - size_t prepare() + void prepare() { - HIP_CHECK(rocprim::merge_inplace(d_temporary_storage, + HIP_CHECK(rocprim::merge_inplace(d_temporary_storage.get(), temporary_storage_bytes, d_data, left_size, right_size, compare_op, stream)); - HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); - - return temporary_storage_bytes; + d_temporary_storage.resize(temporary_storage_bytes); } void run() { - HIP_CHECK(rocprim::merge_inplace(d_temporary_storage, + HIP_CHECK(rocprim::merge_inplace(d_temporary_storage.get(), temporary_storage_bytes, d_data, left_size, @@ -140,20 +132,16 @@ struct inplace_runner compare_op, stream)); } - - void clean() - { - HIP_CHECK(hipFree(d_temporary_storage)); - } }; template -void run_merge_inplace_benchmarks(benchmark::State& state, - size_t size_a, - size_t size_b, - const managed_seed& seed, - hipStream_t stream) +void run_merge_inplace_benchmarks(benchmark_utils::state&& state) { + const auto& stream = state.stream; + const auto& size_a = state.bytes; + const auto& size_b = state.bytes; + const auto& seed = state.seed; + using value_type = ValueT; using runner_type = RunnerT; @@ -175,117 +163,39 @@ void run_merge_inplace_benchmarks(benchmark::State& state, h_data[size_a + i] = static_cast(*(gen_b_it++)); } - size_t num_bytes = total_size * sizeof(value_type); - - value_type* d_data; - - HIP_CHECK(hipMalloc(&d_data, num_bytes)); - - runner_type runner{d_data, size_a, size_b, stream}; + common::device_ptr d_data(total_size); - size_t temp_storage_size = runner.prepare(); + runner_type runner{d_data.get(), size_a, size_b, stream}; - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); + runner.prepare(); - for(auto _ : state) - { - HIP_CHECK(hipMemcpy(d_data, h_data.data(), num_bytes, hipMemcpyHostToDevice)); + state.run_before_every_iteration([&] { d_data.store(h_data); }); - HIP_CHECK(hipEventRecord(start, stream)); - runner.run(); - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); + state.run([&] { runner.run(); }); - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - - std::stringstream label; - label << "temp_storage=" << temp_storage_size; - - state.SetLabel(label.str()); - } - - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * num_bytes); - state.SetItemsProcessed(state.iterations() * total_size); - - HIP_CHECK(hipFree(d_data)); - runner.clean(); + state.set_throughput(total_size, sizeof(value_type)); } -#define CREATE_MERGE_INPLACE_BENCHMARK(Value) \ - benchmark::RegisterBenchmark( \ +#define CREATE_BENCHMARK(Value) \ + executor.queue_fn( \ bench_naming::format_name("{lvl:device,algo:merge_inplace,value_type:" #Value \ ",cfg:default_config}") \ .c_str(), \ - [=](benchmark::State& state) { \ - run_merge_inplace_benchmarks>(state, \ - size, \ - size, \ - seed, \ - stream); \ - }) + [=](benchmark_utils::state&& state) \ + { \ + run_merge_inplace_benchmarks>( \ + std::forward(state)); \ + }); int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of values"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); - - // HIP - hipStream_t stream = hipStreamDefault; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks = { - CREATE_MERGE_INPLACE_BENCHMARK(int8_t), - CREATE_MERGE_INPLACE_BENCHMARK(int16_t), - CREATE_MERGE_INPLACE_BENCHMARK(int32_t), - CREATE_MERGE_INPLACE_BENCHMARK(int64_t), - CREATE_MERGE_INPLACE_BENCHMARK(rocprim::int128_t), - }; - - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } + benchmark_utils::executor executor(argc, argv, 128 * benchmark_utils::MiB, 1, 0); - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); + CREATE_BENCHMARK(int8_t) + CREATE_BENCHMARK(int16_t) + CREATE_BENCHMARK(int32_t) + CREATE_BENCHMARK(int64_t) + CREATE_BENCHMARK(rocprim::int128_t) - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_merge_sort.cpp b/benchmark/benchmark_device_merge_sort.cpp index e237ca3c6..1c818c03d 100644 --- a/benchmark/benchmark_device_merge_sort.cpp +++ b/benchmark/benchmark_device_merge_sort.cpp @@ -23,14 +23,8 @@ #include "benchmark_device_merge_sort.hpp" #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" - #include "../common/utils_custom_type.hpp" -// Google Benchmark -#include - // HIP API #include @@ -41,46 +35,12 @@ #include #include -#ifndef DEFAULT_BYTES -const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; -#endif - -#define CREATE_BENCHMARK(...) \ - { \ - const device_merge_sort_benchmark<__VA_ARGS__> instance; \ - REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ - } +#define CREATE_BENCHMARK(...) executor.queue_instance(device_merge_sort_benchmark<__VA_ARGS__>()); int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); + benchmark_utils::executor executor(argc, argv, 128 * benchmark_utils::MiB, 10, 5); - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks = {}; CREATE_BENCHMARK(int) CREATE_BENCHMARK(long long) CREATE_BENCHMARK(int8_t) @@ -115,23 +75,5 @@ int main(int argc, char* argv[]) CREATE_BENCHMARK(rocprim::int128_t, rocprim::int128_t) CREATE_BENCHMARK(rocprim::uint128_t, rocprim::uint128_t) - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } - - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_merge_sort.hpp b/benchmark/benchmark_device_merge_sort.hpp index 33754645b..b72a06aca 100644 --- a/benchmark/benchmark_device_merge_sort.hpp +++ b/benchmark/benchmark_device_merge_sort.hpp @@ -45,7 +45,7 @@ namespace rp = rocprim; template -struct device_merge_sort_benchmark : public config_autotune_interface +struct device_merge_sort_benchmark : public benchmark_utils::autotune_interface { std::string name() const override { @@ -54,17 +54,15 @@ struct device_merge_sort_benchmark : public config_autotune_interface + ",value_type:" + std::string(Traits::name()) + ",cfg:default_config}"); } - static constexpr unsigned int batch_size = 10; - static constexpr unsigned int warmup_size = 5; - // keys benchmark template - auto do_run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const -> + auto do_run(benchmark_utils::state&& state) const -> typename std::enable_if::value, void>::type { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + using key_type = Key; // Calculate the number of elements @@ -101,31 +99,8 @@ struct device_merge_sort_benchmark : public config_autotune_interface HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); HIP_CHECK(hipDeviceSynchronize()); - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - HIP_CHECK(rp::merge_sort(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - d_keys_output, - size, - lesser_op, - stream, - false)); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) + state.run( + [&] { HIP_CHECK(rp::merge_sort(d_temporary_storage, temporary_storage_bytes, @@ -135,24 +110,9 @@ struct device_merge_sort_benchmark : public config_autotune_interface lesser_op, stream, false)); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(key_type)); - state.SetItemsProcessed(state.iterations() * batch_size * size); + }); + state.set_throughput(size, sizeof(key_type)); HIP_CHECK(hipFree(d_temporary_storage)); HIP_CHECK(hipFree(d_keys_input)); HIP_CHECK(hipFree(d_keys_output)); @@ -160,12 +120,13 @@ struct device_merge_sort_benchmark : public config_autotune_interface // pairs benchmark template - auto do_run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const -> + auto do_run(benchmark_utils::state&& state) const -> typename std::enable_if::value, void>::type { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + using key_type = Key; using value_type = Value; @@ -217,33 +178,8 @@ struct device_merge_sort_benchmark : public config_autotune_interface HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); HIP_CHECK(hipDeviceSynchronize()); - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - HIP_CHECK(rp::merge_sort(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - d_keys_output, - d_values_input, - d_values_output, - size, - lesser_op, - stream, - false)); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) + state.run( + [&] { HIP_CHECK(rp::merge_sort(d_temporary_storage, temporary_storage_bytes, @@ -255,24 +191,9 @@ struct device_merge_sort_benchmark : public config_autotune_interface lesser_op, stream, false)); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); + }); - state.SetBytesProcessed(state.iterations() * batch_size * size - * (sizeof(key_type) + sizeof(value_type))); - state.SetItemsProcessed(state.iterations() * batch_size * size); + state.set_throughput(size, sizeof(key_type) + sizeof(value_type)); HIP_CHECK(hipFree(d_temporary_storage)); HIP_CHECK(hipFree(d_keys_input)); @@ -281,12 +202,9 @@ struct device_merge_sort_benchmark : public config_autotune_interface HIP_CHECK(hipFree(d_values_output)); } - void run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { - do_run(state, bytes, seed, stream); + do_run(std::forward(state)); } }; diff --git a/benchmark/benchmark_device_merge_sort_block_merge.cpp b/benchmark/benchmark_device_merge_sort_block_merge.cpp index 813d485e0..6fcc77462 100644 --- a/benchmark/benchmark_device_merge_sort_block_merge.cpp +++ b/benchmark/benchmark_device_merge_sort_block_merge.cpp @@ -23,16 +23,10 @@ #include "benchmark_device_merge_sort_block_merge.parallel.hpp" #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" - #ifndef BENCHMARK_CONFIG_TUNING #include "../common/utils_custom_type.hpp" #endif -// Google Benchmark -#include - // HIP API #include @@ -47,67 +41,14 @@ #include #endif -#ifndef DEFAULT_N -const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; -#endif - -#define CREATE_BENCHMARK(...) \ - { \ - const device_merge_sort_block_merge_benchmark<__VA_ARGS__> instance; \ - REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ - } +#define CREATE_BENCHMARK(...) \ + executor.queue_instance(device_merge_sort_block_merge_benchmark<__VA_ARGS__>()); int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); -#ifdef BENCHMARK_CONFIG_TUNING - // optionally run an evenly split subset of benchmarks, when making multiple program invocations - parser.set_optional("parallel_instance", - "parallel_instance", - 0, - "parallel instance index"); - parser.set_optional("parallel_instances", - "parallel_instances", - 1, - "total parallel instances"); -#endif - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); + benchmark_utils::executor executor(argc, argv, 128 * benchmark_utils::MiB, 10, 5); - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks = {}; -#ifdef BENCHMARK_CONFIG_TUNING - const int parallel_instance = parser.get("parallel_instance"); - const int parallel_instances = parser.get("parallel_instances"); - config_autotune_register::register_benchmark_subset(benchmarks, - parallel_instance, - parallel_instances, - bytes, - seed, - stream); -#else // BENCHMARK_CONFIG_TUNING +#ifndef BENCHMARK_CONFIG_TUNING CREATE_BENCHMARK(int) CREATE_BENCHMARK(long long) CREATE_BENCHMARK(int8_t) @@ -137,25 +78,7 @@ int main(int argc, char* argv[]) CREATE_BENCHMARK(custom_int2, custom_longlong_double) CREATE_BENCHMARK(rocprim::int128_t, rocprim::int128_t) CREATE_BENCHMARK(rocprim::uint128_t, rocprim::uint128_t) -#endif // BENCHMARK_CONFIG_TUNING - - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } +#endif - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_merge_sort_block_merge.parallel.cpp.in b/benchmark/benchmark_device_merge_sort_block_merge.parallel.cpp.in index 5d34eacb3..63b3d5a91 100644 --- a/benchmark/benchmark_device_merge_sort_block_merge.parallel.cpp.in +++ b/benchmark/benchmark_device_merge_sort_block_merge.parallel.cpp.in @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -28,6 +28,6 @@ #include namespace { - auto benchmarks = config_autotune_register::create_bulk( + auto unused = benchmark_utils::executor::queue_autotune( device_merge_sort_block_merge_benchmark_generator<@BlockSize@, @UseMergePath@, @KeyType@, @ValueType@>::create); } diff --git a/benchmark/benchmark_device_merge_sort_block_merge.parallel.hpp b/benchmark/benchmark_device_merge_sort_block_merge.parallel.hpp index ff9a1a804..16ab693e2 100644 --- a/benchmark/benchmark_device_merge_sort_block_merge.parallel.hpp +++ b/benchmark/benchmark_device_merge_sort_block_merge.parallel.hpp @@ -26,6 +26,7 @@ #include "benchmark_utils.hpp" #include "../common/utils_data_generation.hpp" +#include "../common/utils_device_ptr.hpp" // Google Benchmark #include @@ -69,7 +70,7 @@ inline std::string config_name() template -struct device_merge_sort_block_merge_benchmark : public config_autotune_interface +struct device_merge_sort_block_merge_benchmark : public benchmark_utils::autotune_interface { std::string name() const override { @@ -79,19 +80,18 @@ struct device_merge_sort_block_merge_benchmark : public config_autotune_interfac + ",cfg:" + config_name() + "}"); } - static constexpr unsigned int batch_size = 10; - static constexpr unsigned int warmup_size = 5; // Because merge_sort_block_merge expects partially sorted input: using block_sort_config = rocprim::default_config; // keys benchmark template - auto do_run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const -> + auto do_run(benchmark_utils::state&& state) const -> typename std::enable_if::value, void>::type { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + using key_type = Key; // Calculate the number of elements @@ -104,14 +104,8 @@ struct device_merge_sort_block_merge_benchmark : public config_autotune_interfac common::generate_limits::max(), seed.get_0()); - key_type* d_keys_input; - key_type* d_keys; - HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys_input), size * sizeof(key_type))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys), size * sizeof(key_type))); - HIP_CHECK(hipMemcpy(d_keys_input, - keys_input.data(), - size * sizeof(key_type), - hipMemcpyHostToDevice)); + common::device_ptr d_keys_input(keys_input); + common::device_ptr d_keys(size); HIP_CHECK(hipDeviceSynchronize()); ::rocprim::less lesser_op; @@ -119,8 +113,8 @@ struct device_merge_sort_block_merge_benchmark : public config_autotune_interfac // Merge_sort_block_merge algorithm expects partially sorted input: unsigned int sorted_block_size; - HIP_CHECK(rocprim::detail::merge_sort_block_sort(d_keys_input, - d_keys_input, + HIP_CHECK(rocprim::detail::merge_sort_block_sort(d_keys_input.get(), + d_keys_input.get(), values_ptr, values_ptr, size, @@ -129,11 +123,10 @@ struct device_merge_sort_block_merge_benchmark : public config_autotune_interfac stream, false)); - void* d_temporary_storage = nullptr; size_t temporary_storage_bytes = 0; - HIP_CHECK(rocprim::detail::merge_sort_block_merge(d_temporary_storage, + HIP_CHECK(rocprim::detail::merge_sort_block_merge(nullptr, temporary_storage_bytes, - d_keys, + d_keys.get(), values_ptr, size, sorted_block_size, @@ -141,30 +134,22 @@ struct device_merge_sort_block_merge_benchmark : public config_autotune_interfac stream, false)); - HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); + common::device_ptr d_temporary_storage(temporary_storage_bytes); HIP_CHECK(hipDeviceSynchronize()); - hipError_t err; - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - err = rocprim::detail::merge_sort_block_merge(d_temporary_storage, - temporary_storage_bytes, - d_keys, - values_ptr, - size, - sorted_block_size, - lesser_op, - stream, - false); - } + hipError_t err = rocprim::detail::merge_sort_block_merge(d_temporary_storage.get(), + temporary_storage_bytes, + d_keys.get(), + values_ptr, + size, + sorted_block_size, + lesser_op, + stream, + false); if(err == hipError_t::hipErrorAssert) { - state.SkipWithError("SKIPPING: block_sort_items_per_block >= " - "block_merge_items_per_block does not hold"); - HIP_CHECK(hipFree(d_temporary_storage)); - HIP_CHECK(hipFree(d_keys_input)); - HIP_CHECK(hipFree(d_keys)); + state.gbench_state.SkipWithError("SKIPPING: block_sort_items_per_block >= " + "block_merge_items_per_block does not hold"); return; } else if(err != hipSuccess) @@ -174,59 +159,42 @@ struct device_merge_sort_block_merge_benchmark : public config_autotune_interfac } HIP_CHECK(hipDeviceSynchronize()); - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipMemcpyAsync(d_keys, - d_keys_input, - size * sizeof(key_type), - hipMemcpyDeviceToDevice, - stream)); - HIP_CHECK(hipEventRecord(start, stream)); - HIP_CHECK(rocprim::detail::merge_sort_block_merge(d_temporary_storage, - temporary_storage_bytes, - d_keys, - values_ptr, - size, - sorted_block_size, - lesser_op, - stream, - false)); - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(key_type)); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_temporary_storage)); - HIP_CHECK(hipFree(d_keys_input)); - HIP_CHECK(hipFree(d_keys)); + state.run_before_every_iteration( + [&] + { + HIP_CHECK(hipMemcpyAsync(d_keys.get(), + d_keys_input.get(), + size * sizeof(key_type), + hipMemcpyDeviceToDevice, + stream)); + }); + + state.run( + [&] + { + HIP_CHECK(rocprim::detail::merge_sort_block_merge(d_temporary_storage.get(), + temporary_storage_bytes, + d_keys.get(), + values_ptr, + size, + sorted_block_size, + lesser_op, + stream, + false)); + }); + + state.set_throughput(size, sizeof(key_type)); } // pairs benchmark template - auto do_run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const -> + auto do_run(benchmark_utils::state&& state) const -> typename std::enable_if::value, void>::type { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + using key_type = Key; using value_type = Value; @@ -243,23 +211,11 @@ struct device_merge_sort_block_merge_benchmark : public config_autotune_interfac std::vector values_input(size); std::iota(values_input.begin(), values_input.end(), 0); - key_type* d_keys_input; - key_type* d_keys; - HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys_input), size * sizeof(key_type))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys), size * sizeof(key_type))); - HIP_CHECK(hipMemcpy(d_keys_input, - keys_input.data(), - size * sizeof(key_type), - hipMemcpyHostToDevice)); - - value_type* d_values_input; - value_type* d_values; - HIP_CHECK(hipMalloc(reinterpret_cast(&d_values_input), size * sizeof(value_type))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_values), size * sizeof(value_type))); - HIP_CHECK(hipMemcpy(d_values_input, - values_input.data(), - size * sizeof(value_type), - hipMemcpyHostToDevice)); + common::device_ptr d_keys_input(keys_input); + common::device_ptr d_keys(size); + + common::device_ptr d_values_input(values_input); + common::device_ptr d_values(size); HIP_CHECK(hipDeviceSynchronize()); @@ -267,54 +223,43 @@ struct device_merge_sort_block_merge_benchmark : public config_autotune_interfac // Merge_sort_block_merge algorithm expects partially sorted input: unsigned int sorted_block_size; - HIP_CHECK(rocprim::detail::merge_sort_block_sort(d_keys_input, - d_keys_input, - d_values_input, - d_values_input, + HIP_CHECK(rocprim::detail::merge_sort_block_sort(d_keys_input.get(), + d_keys_input.get(), + d_values_input.get(), + d_values_input.get(), size, sorted_block_size, lesser_op, stream, false)); - void* d_temporary_storage = nullptr; size_t temporary_storage_bytes = 0; - HIP_CHECK(rocprim::detail::merge_sort_block_merge(d_temporary_storage, + HIP_CHECK(rocprim::detail::merge_sort_block_merge(nullptr, temporary_storage_bytes, - d_keys, - d_values, + d_keys.get(), + d_values.get(), size, sorted_block_size, lesser_op, stream, false)); - HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); + common::device_ptr d_temporary_storage(temporary_storage_bytes); HIP_CHECK(hipDeviceSynchronize()); - hipError_t err; - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - err = rocprim::detail::merge_sort_block_merge(d_temporary_storage, - temporary_storage_bytes, - d_keys, - d_values, - size, - sorted_block_size, - lesser_op, - stream, - false); - } + hipError_t err = rocprim::detail::merge_sort_block_merge(d_temporary_storage.get(), + temporary_storage_bytes, + d_keys.get(), + d_values.get(), + size, + sorted_block_size, + lesser_op, + stream, + false); if(err == hipError_t::hipErrorAssert) { - state.SkipWithError("SKIPPING: block_sort_items_per_block >= " - "block_merge_items_per_block does not hold"); - HIP_CHECK(hipFree(d_temporary_storage)); - HIP_CHECK(hipFree(d_keys_input)); - HIP_CHECK(hipFree(d_keys)); - HIP_CHECK(hipFree(d_values_input)); - HIP_CHECK(hipFree(d_values)); + state.gbench_state.SkipWithError("SKIPPING: block_sort_items_per_block >= " + "block_merge_items_per_block does not hold"); return; } else if(err != hipSuccess) @@ -324,64 +269,41 @@ struct device_merge_sort_block_merge_benchmark : public config_autotune_interfac } HIP_CHECK(hipDeviceSynchronize()); - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipMemcpyAsync(d_keys, - d_keys_input, - size * sizeof(key_type), - hipMemcpyDeviceToDevice, - stream)); - HIP_CHECK(hipMemcpyAsync(d_values, - d_values_input, - size * sizeof(value_type), - hipMemcpyDeviceToDevice, - stream)); - HIP_CHECK(hipEventRecord(start, stream)); - HIP_CHECK(rocprim::detail::merge_sort_block_merge(d_temporary_storage, - temporary_storage_bytes, - d_keys, - d_values, - size, - sorted_block_size, - lesser_op, - stream, - false)); - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(key_type)); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_temporary_storage)); - HIP_CHECK(hipFree(d_keys_input)); - HIP_CHECK(hipFree(d_keys)); - HIP_CHECK(hipFree(d_values_input)); - HIP_CHECK(hipFree(d_values)); + state.run_before_every_iteration( + [&] + { + HIP_CHECK(hipMemcpyAsync(d_keys.get(), + d_keys_input.get(), + size * sizeof(key_type), + hipMemcpyDeviceToDevice, + stream)); + HIP_CHECK(hipMemcpyAsync(d_values.get(), + d_values_input.get(), + size * sizeof(value_type), + hipMemcpyDeviceToDevice, + stream)); + }); + + state.run( + [&] + { + HIP_CHECK(rocprim::detail::merge_sort_block_merge(d_temporary_storage.get(), + temporary_storage_bytes, + d_keys.get(), + d_values.get(), + size, + sorted_block_size, + lesser_op, + stream, + false)); + }); + + state.set_throughput(size, sizeof(key_type)); } - void run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { - do_run(state, bytes, seed, stream); + do_run(std::forward(state)); } }; @@ -409,13 +331,13 @@ struct device_merge_sort_block_merge_benchmark_generator using benchmark_struct = device_merge_sort_block_merge_benchmark; - void operator()(std::vector>& storage) + void operator()(std::vector>& storage) { storage.emplace_back(std::make_unique()); } }; - static void create(std::vector>& storage) + static void create(std::vector>& storage) { static constexpr unsigned int min_items_per_thread_exponent = 0u; diff --git a/benchmark/benchmark_device_merge_sort_block_sort.cpp b/benchmark/benchmark_device_merge_sort_block_sort.cpp index c7f60596d..bfc8d16d9 100644 --- a/benchmark/benchmark_device_merge_sort_block_sort.cpp +++ b/benchmark/benchmark_device_merge_sort_block_sort.cpp @@ -23,16 +23,10 @@ #include "benchmark_device_merge_sort_block_sort.parallel.hpp" #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" - #ifndef BENCHMARK_CONFIG_TUNING #include "../common/utils_custom_type.hpp" #endif -// Google Benchmark -#include - // HIP API #include @@ -47,67 +41,14 @@ #include #endif -#ifndef DEFAULT_N -const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; -#endif - -#define CREATE_BENCHMARK(...) \ - { \ - const device_merge_sort_block_sort_benchmark<__VA_ARGS__> instance; \ - REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ - } +#define CREATE_BENCHMARK(...) \ + executor.queue_instance(device_merge_sort_block_sort_benchmark<__VA_ARGS__>()); int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); -#ifdef BENCHMARK_CONFIG_TUNING - // optionally run an evenly split subset of benchmarks, when making multiple program invocations - parser.set_optional("parallel_instance", - "parallel_instance", - 0, - "parallel instance index"); - parser.set_optional("parallel_instances", - "parallel_instances", - 1, - "total parallel instances"); -#endif - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); + benchmark_utils::executor executor(argc, argv, 128 * benchmark_utils::MiB, 10, 5); - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks = {}; -#ifdef BENCHMARK_CONFIG_TUNING - const int parallel_instance = parser.get("parallel_instance"); - const int parallel_instances = parser.get("parallel_instances"); - config_autotune_register::register_benchmark_subset(benchmarks, - parallel_instance, - parallel_instances, - bytes, - seed, - stream); -#else // BENCHMARK_CONFIG_TUNING +#ifndef BENCHMARK_CONFIG_TUNING CREATE_BENCHMARK(int) CREATE_BENCHMARK(long long) CREATE_BENCHMARK(int8_t) @@ -139,25 +80,7 @@ int main(int argc, char* argv[]) CREATE_BENCHMARK(int, custom_char_short) CREATE_BENCHMARK(rocprim::int128_t, rocprim::int128_t) CREATE_BENCHMARK(rocprim::uint128_t, rocprim::uint128_t) -#endif // BENCHMARK_CONFIG_TUNING - - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } +#endif - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_merge_sort_block_sort.parallel.cpp.in b/benchmark/benchmark_device_merge_sort_block_sort.parallel.cpp.in index c0e241034..845cbbe0f 100644 --- a/benchmark/benchmark_device_merge_sort_block_sort.parallel.cpp.in +++ b/benchmark/benchmark_device_merge_sort_block_sort.parallel.cpp.in @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -28,6 +28,6 @@ #include namespace { - auto benchmarks = config_autotune_register::create_bulk( + auto unused = benchmark_utils::executor::queue_autotune( device_merge_sort_block_sort_benchmark_generator<@BlockSize@, @BlockSortMethod@, @KeyType@, @ValueType@>::create); } diff --git a/benchmark/benchmark_device_merge_sort_block_sort.parallel.hpp b/benchmark/benchmark_device_merge_sort_block_sort.parallel.hpp index f9baa5308..4a823c943 100644 --- a/benchmark/benchmark_device_merge_sort_block_sort.parallel.hpp +++ b/benchmark/benchmark_device_merge_sort_block_sort.parallel.hpp @@ -26,6 +26,7 @@ #include "benchmark_utils.hpp" #include "../common/utils_data_generation.hpp" +#include "../common/utils_device_ptr.hpp" // Google Benchmark #include @@ -79,7 +80,7 @@ inline std::string config_name() template -struct device_merge_sort_block_sort_benchmark : public config_autotune_interface +struct device_merge_sort_block_sort_benchmark : public benchmark_utils::autotune_interface { std::string name() const override { @@ -90,17 +91,15 @@ struct device_merge_sort_block_sort_benchmark : public config_autotune_interface + ",cfg:" + config_name() + "}"); } - static constexpr unsigned int batch_size = 10; - static constexpr unsigned int warmup_size = 5; - // keys benchmark template - auto do_run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const -> + auto do_run(benchmark_utils::state&& state) const -> typename std::enable_if::value, void>::type { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + using key_type = Key; // Calculate the number of elements @@ -112,47 +111,18 @@ struct device_merge_sort_block_sort_benchmark : public config_autotune_interface common::generate_limits::max(), seed.get_0()); - key_type* d_keys_input; - key_type* d_keys_output; - HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys_input), size * sizeof(key_type))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys_output), size * sizeof(key_type))); - HIP_CHECK(hipMemcpy(d_keys_input, - keys_input.data(), - size * sizeof(key_type), - hipMemcpyHostToDevice)); + common::device_ptr d_keys_input(keys_input); + common::device_ptr d_keys_output(size); ::rocprim::less lesser_op; rocprim::empty_type* values_ptr = nullptr; unsigned int items_per_block; - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - HIP_CHECK(rocprim::detail::merge_sort_block_sort(d_keys_input, - d_keys_output, - values_ptr, - values_ptr, - size, - items_per_block, - lesser_op, - stream, - false)); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - for(size_t i = 0; i < batch_size; ++i) + state.run( + [&] { - HIP_CHECK(rocprim::detail::merge_sort_block_sort(d_keys_input, - d_keys_output, + HIP_CHECK(rocprim::detail::merge_sort_block_sort(d_keys_input.get(), + d_keys_output.get(), values_ptr, values_ptr, size, @@ -160,36 +130,20 @@ struct device_merge_sort_block_sort_benchmark : public config_autotune_interface lesser_op, stream, false)); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } + }); - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(key_type)); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_keys_input)); - HIP_CHECK(hipFree(d_keys_output)); + state.set_throughput(size, sizeof(key_type)); } // pairs benchmark template - auto do_run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const -> + auto do_run(benchmark_utils::state&& state) const -> typename std::enable_if::value, void>::type { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + using key_type = Key; using value_type = Value; @@ -205,97 +159,37 @@ struct device_merge_sort_block_sort_benchmark : public config_autotune_interface std::vector values_input(size); std::iota(values_input.begin(), values_input.end(), 0); - key_type* d_keys_input; - key_type* d_keys_output; - HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys_input), size * sizeof(key_type))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys_output), size * sizeof(key_type))); - HIP_CHECK(hipMemcpy(d_keys_input, - keys_input.data(), - size * sizeof(key_type), - hipMemcpyHostToDevice)); - - value_type* d_values_input; - value_type* d_values_output; - HIP_CHECK(hipMalloc(reinterpret_cast(&d_values_input), size * sizeof(value_type))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_values_output), size * sizeof(value_type))); - HIP_CHECK(hipMemcpy(d_values_input, - values_input.data(), - size * sizeof(value_type), - hipMemcpyHostToDevice)); + common::device_ptr d_keys_input(keys_input); + common::device_ptr d_keys_output(size); + + common::device_ptr d_values_input(values_input); + common::device_ptr d_values_output(size); ::rocprim::less lesser_op; unsigned int items_per_block; HIP_CHECK(hipDeviceSynchronize()); - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - - HIP_CHECK(rocprim::detail::merge_sort_block_sort(d_keys_input, - d_keys_output, - d_values_input, - d_values_output, - size, - items_per_block, - lesser_op, - stream, - false)); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) + state.run( + [&] { - HIP_CHECK(rocprim::detail::merge_sort_block_sort(d_keys_input, - d_keys_output, - d_values_input, - d_values_output, + HIP_CHECK(rocprim::detail::merge_sort_block_sort(d_keys_input.get(), + d_keys_output.get(), + d_values_input.get(), + d_values_output.get(), size, items_per_block, lesser_op, stream, false)); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size - * (sizeof(key_type) + sizeof(value_type))); - state.SetItemsProcessed(state.iterations() * batch_size * size); + }); - HIP_CHECK(hipFree(d_keys_input)); - HIP_CHECK(hipFree(d_keys_output)); - HIP_CHECK(hipFree(d_values_input)); - HIP_CHECK(hipFree(d_values_output)); + state.set_throughput(size, sizeof(key_type) + sizeof(value_type)); } - void run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { - do_run(state, bytes, seed, stream); + do_run(std::forward(state)); } }; @@ -312,7 +206,7 @@ struct device_merge_sort_block_sort_benchmark_generator using generated_config = rocprim::detail::merge_sort_block_sort_config; - void operator()(std::vector>& storage) + void operator()(std::vector>& storage) { storage.emplace_back( std::make_unique< @@ -320,7 +214,7 @@ struct device_merge_sort_block_sort_benchmark_generator } }; - static void create(std::vector>& storage) + static void create(std::vector>& storage) { // Sort_items_per_block must be equal or larger than merge_items_per_block, so make // the items_per_thread at least as large so the sort_items_per_block diff --git a/benchmark/benchmark_device_nth_element.cpp b/benchmark/benchmark_device_nth_element.cpp index 37329a4eb..43741c980 100644 --- a/benchmark/benchmark_device_nth_element.cpp +++ b/benchmark/benchmark_device_nth_element.cpp @@ -23,14 +23,8 @@ #include "benchmark_device_nth_element.hpp" #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" - #include "../common/utils_custom_type.hpp" -// Google Benchmark -#include - // HIP API #include @@ -41,15 +35,8 @@ #include #include -#ifndef DEFAULT_BYTES -const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; -#endif - -#define CREATE_BENCHMARK_NTH_ELEMENT(TYPE, SMALL_N) \ - { \ - const device_nth_element_benchmark instance(SMALL_N); \ - REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ - } +#define CREATE_BENCHMARK_NTH_ELEMENT(TYPE, SMALL_N) \ + executor.queue_instance(device_nth_element_benchmark(SMALL_N)); #define CREATE_BENCHMARK(TYPE) \ { \ @@ -58,34 +45,8 @@ const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); - - // HIP - hipStream_t stream = 0; // default + benchmark_utils::executor executor(argc, argv, 128 * benchmark_utils::MiB, 10, 5); - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks{}; CREATE_BENCHMARK(int) CREATE_BENCHMARK(long long) CREATE_BENCHMARK(int8_t) @@ -108,23 +69,5 @@ int main(int argc, char* argv[]) CREATE_BENCHMARK(custom_char_double) CREATE_BENCHMARK(custom_longlong_double) - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } - - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_nth_element.hpp b/benchmark/benchmark_device_nth_element.hpp index 2edcce832..b46a7ef08 100644 --- a/benchmark/benchmark_device_nth_element.hpp +++ b/benchmark/benchmark_device_nth_element.hpp @@ -26,6 +26,7 @@ #include "benchmark_utils.hpp" #include "../common/utils_data_generation.hpp" +#include "../common/utils_device_ptr.hpp" // Google Benchmark #include @@ -43,7 +44,7 @@ #include template -struct device_nth_element_benchmark : public config_autotune_interface +struct device_nth_element_benchmark : public benchmark_utils::autotune_interface { bool small_n = false; @@ -60,14 +61,12 @@ struct device_nth_element_benchmark : public config_autotune_interface + ",key_type:" + std::string(Traits::name()) + ",cfg:default_config}"); } - static constexpr unsigned int batch_size = 10; - static constexpr unsigned int warmup_size = 5; - - void run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + using key_type = Key; // Calculate the number of elements @@ -86,90 +85,39 @@ struct device_nth_element_benchmark : public config_autotune_interface common::generate_limits::max(), seed.get_0()); - key_type* d_keys_input; - key_type* d_keys_output; - HIP_CHECK(hipMalloc(&d_keys_input, size * sizeof(*d_keys_input))); - HIP_CHECK(hipMalloc(&d_keys_output, size * sizeof(*d_keys_output))); - - HIP_CHECK(hipMemcpy(d_keys_input, - keys_input.data(), - size * sizeof(*d_keys_input), - hipMemcpyHostToDevice)); + common::device_ptr d_keys_input(keys_input); + common::device_ptr d_keys_output(size); ::rocprim::less lesser_op; - void* d_temporary_storage = nullptr; size_t temporary_storage_bytes = 0; - HIP_CHECK(rocprim::nth_element(d_temporary_storage, + HIP_CHECK(rocprim::nth_element(nullptr, temporary_storage_bytes, - d_keys_input, - d_keys_output, + d_keys_input.get(), + d_keys_output.get(), nth, size, lesser_op, stream, false)); - HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); - - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - HIP_CHECK(rocprim::nth_element(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - d_keys_output, - nth, - size, - lesser_op, - stream, - false)); - } - HIP_CHECK(hipDeviceSynchronize()); + common::device_ptr d_temporary_storage(temporary_storage_bytes); - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) + state.run( + [&] { - HIP_CHECK(rocprim::nth_element(d_temporary_storage, + HIP_CHECK(rocprim::nth_element(d_temporary_storage.get(), temporary_storage_bytes, - d_keys_input, - d_keys_output, + d_keys_input.get(), + d_keys_output.get(), nth, size, lesser_op, stream, false)); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } + }); - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(*d_keys_input)); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_temporary_storage)); - HIP_CHECK(hipFree(d_keys_input)); - HIP_CHECK(hipFree(d_keys_output)); + state.set_throughput(size, sizeof(key_type)); } }; - #endif // ROCPRIM_BENCHMARK_DEVICE_NTH_ELEMENT_PARALLEL_HPP_ diff --git a/benchmark/benchmark_device_partial_sort.cpp b/benchmark/benchmark_device_partial_sort.cpp index ef5de791a..97fb8ab23 100644 --- a/benchmark/benchmark_device_partial_sort.cpp +++ b/benchmark/benchmark_device_partial_sort.cpp @@ -23,14 +23,8 @@ #include "benchmark_device_partial_sort.hpp" #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" - #include "../common/utils_custom_type.hpp" -// Google Benchmark -#include - // HIP API #include @@ -41,15 +35,8 @@ #include #include -#ifndef DEFAULT_BYTES -const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; -#endif - -#define CREATE_BENCHMARK_PARTIAL_SORT(TYPE, SMALL_N) \ - { \ - const device_partial_sort_benchmark instance(SMALL_N); \ - REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ - } +#define CREATE_BENCHMARK_PARTIAL_SORT(TYPE, SMALL_N) \ + executor.queue_instance(device_partial_sort_benchmark(SMALL_N)); #define CREATE_BENCHMARK(TYPE) \ { \ @@ -58,34 +45,9 @@ const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); - // HIP - hipStream_t stream = 0; // default + benchmark_utils::executor executor(argc, argv, 128 * benchmark_utils::MiB, 10, 5); - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks{}; CREATE_BENCHMARK(int) CREATE_BENCHMARK(long long) CREATE_BENCHMARK(int8_t) @@ -108,23 +70,5 @@ int main(int argc, char* argv[]) CREATE_BENCHMARK(custom_char_double) CREATE_BENCHMARK(custom_longlong_double) - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } - - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_partial_sort.hpp b/benchmark/benchmark_device_partial_sort.hpp index 413aa19b7..4c6d44341 100644 --- a/benchmark/benchmark_device_partial_sort.hpp +++ b/benchmark/benchmark_device_partial_sort.hpp @@ -43,7 +43,7 @@ #include template -struct device_partial_sort_benchmark : public config_autotune_interface +struct device_partial_sort_benchmark : public benchmark_utils::autotune_interface { bool small_n = false; @@ -60,14 +60,13 @@ struct device_partial_sort_benchmark : public config_autotune_interface + ",key_type:" + std::string(Traits::name()) + ",cfg:default_config}"); } - static constexpr unsigned int batch_size = 10; - static constexpr unsigned int warmup_size = 5; - - void run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { + + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + using key_type = Key; // Calculate the number of elements size_t size = bytes / sizeof(key_type); @@ -110,40 +109,18 @@ struct device_partial_sort_benchmark : public config_autotune_interface HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - HIP_CHECK(hipMemcpy(d_keys_input, - d_keys_new_data, - size * sizeof(*d_keys_input), - hipMemcpyDeviceToDevice)); - HIP_CHECK(rocprim::partial_sort(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - middle, - size, - lesser_op, - stream, - false)); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - float elapsed_mseconds = 0; - for(size_t i = 0; i < batch_size; ++i) + state.run_before_every_iteration( + [&] { HIP_CHECK(hipMemcpy(d_keys_input, d_keys_new_data, size * sizeof(*d_keys_input), hipMemcpyDeviceToDevice)); - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); + }); + + state.run( + [&] + { HIP_CHECK(rocprim::partial_sort(d_temporary_storage, temporary_storage_bytes, d_keys_input, @@ -152,23 +129,9 @@ struct device_partial_sort_benchmark : public config_autotune_interface lesser_op, stream, false)); - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - float elapsed_mseconds_current; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds_current, start, stop)); - elapsed_mseconds += elapsed_mseconds_current; - } - - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); + }); - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(*d_keys_input)); - state.SetItemsProcessed(state.iterations() * batch_size * size); + state.set_throughput(size, sizeof(key_type)); HIP_CHECK(hipFree(d_temporary_storage)); HIP_CHECK(hipFree(d_keys_input)); diff --git a/benchmark/benchmark_device_partial_sort_copy.cpp b/benchmark/benchmark_device_partial_sort_copy.cpp index b7d2394a8..27e375fe2 100644 --- a/benchmark/benchmark_device_partial_sort_copy.cpp +++ b/benchmark/benchmark_device_partial_sort_copy.cpp @@ -23,14 +23,8 @@ #include "benchmark_device_partial_sort_copy.hpp" #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" - #include "../common/utils_custom_type.hpp" -// Google Benchmark -#include - // HIP API #include @@ -41,15 +35,8 @@ #include #include -#ifndef DEFAULT_BYTES -const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; -#endif - -#define CREATE_BENCHMARK_PARTIAL_SORT_COPY(TYPE, SMALL_N) \ - { \ - const device_partial_sort_copy_benchmark instance(SMALL_N); \ - REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ - } +#define CREATE_BENCHMARK_PARTIAL_SORT_COPY(TYPE, SMALL_N) \ + executor.queue_instance(device_partial_sort_copy_benchmark(SMALL_N)); #define CREATE_BENCHMARK(TYPE) \ { \ @@ -59,34 +46,8 @@ const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); - - // HIP - hipStream_t stream = 0; // default + benchmark_utils::executor executor(argc, argv, 128 * benchmark_utils::MiB, 10, 5); - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks{}; CREATE_BENCHMARK(int) CREATE_BENCHMARK(long long) CREATE_BENCHMARK(int8_t) @@ -109,23 +70,5 @@ int main(int argc, char* argv[]) CREATE_BENCHMARK(custom_char_double) CREATE_BENCHMARK(custom_longlong_double) - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } - - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_partial_sort_copy.hpp b/benchmark/benchmark_device_partial_sort_copy.hpp index f3e8bfd22..d5872432b 100644 --- a/benchmark/benchmark_device_partial_sort_copy.hpp +++ b/benchmark/benchmark_device_partial_sort_copy.hpp @@ -26,6 +26,7 @@ #include "benchmark_utils.hpp" #include "../common/utils_data_generation.hpp" +#include "../common/utils_device_ptr.hpp" // Google Benchmark #include @@ -43,7 +44,7 @@ #include template -struct device_partial_sort_copy_benchmark : public config_autotune_interface +struct device_partial_sort_copy_benchmark : public benchmark_utils::autotune_interface { bool small_n = false; @@ -60,14 +61,12 @@ struct device_partial_sort_copy_benchmark : public config_autotune_interface + ",key_type:" + std::string(Traits::name()) + ",cfg:default_config}"); } - static constexpr unsigned int batch_size = 10; - static constexpr unsigned int warmup_size = 5; - - void run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + using key_type = Key; // Calculate the number of elements @@ -86,88 +85,39 @@ struct device_partial_sort_copy_benchmark : public config_autotune_interface common::generate_limits::max(), seed.get_0()); - key_type* d_keys_input; - key_type* d_keys_output; - HIP_CHECK(hipMalloc(&d_keys_input, size * sizeof(*d_keys_input))); - HIP_CHECK(hipMalloc(&d_keys_output, size * sizeof(*d_keys_output))); - - HIP_CHECK(hipMemcpy(d_keys_input, - keys_input.data(), - size * sizeof(*d_keys_input), - hipMemcpyHostToDevice)); + common::device_ptr d_keys_input(keys_input); + common::device_ptr d_keys_output(size); rocprim::less lesser_op; - void* d_temporary_storage = nullptr; + common::device_ptr d_temporary_storage; size_t temporary_storage_bytes = 0; - HIP_CHECK(rocprim::partial_sort_copy(d_temporary_storage, + HIP_CHECK(rocprim::partial_sort_copy(d_temporary_storage.get(), temporary_storage_bytes, - d_keys_input, - d_keys_output, + d_keys_input.get(), + d_keys_output.get(), middle, size, lesser_op, stream, false)); - HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); - - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - HIP_CHECK(rocprim::partial_sort_copy(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - d_keys_output, - middle, - size, - lesser_op, - stream, - false)); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); + d_temporary_storage.resize(temporary_storage_bytes); - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) + state.run( + [&] { - HIP_CHECK(rocprim::partial_sort_copy(d_temporary_storage, + HIP_CHECK(rocprim::partial_sort_copy(d_temporary_storage.get(), temporary_storage_bytes, - d_keys_input, - d_keys_output, + d_keys_input.get(), + d_keys_output.get(), middle, size, lesser_op, stream, false)); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(*d_keys_input)); - state.SetItemsProcessed(state.iterations() * batch_size * size); + }); - HIP_CHECK(hipFree(d_temporary_storage)); - HIP_CHECK(hipFree(d_keys_input)); - HIP_CHECK(hipFree(d_keys_output)); + state.set_throughput(size, sizeof(key_type)); } }; diff --git a/benchmark/benchmark_device_partition.cpp b/benchmark/benchmark_device_partition.cpp index 531bb8ccb..579b0b730 100644 --- a/benchmark/benchmark_device_partition.cpp +++ b/benchmark/benchmark_device_partition.cpp @@ -22,16 +22,11 @@ #include "benchmark_device_partition.parallel.hpp" #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" #ifndef BENCHMARK_CONFIG_TUNING #include "../common/utils_custom_type.hpp" #endif -// Google Benchmark -#include - // HIP API #include @@ -48,200 +43,117 @@ #include #endif -#ifndef DEFAULT_BYTES -const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; -#endif +#define CREATE_PARTITION_FLAG_BENCHMARK(T, F, p) \ + executor.queue_instance(device_partition_flag_benchmark()); + +#define CREATE_PARTITION_PREDICATE_BENCHMARK(T, p) \ + executor.queue_instance(device_partition_predicate_benchmark()); + +#define CREATE_PARTITION_TWO_WAY_FLAG_BENCHMARK(T, F, p) \ + executor.queue_instance( \ + device_partition_two_way_flag_benchmark()); + +#define CREATE_PARTITION_TWO_WAY_PREDICATE_BENCHMARK(T, p) \ + executor.queue_instance( \ + device_partition_two_way_predicate_benchmark()); -#define CREATE_PARTITION_FLAG_BENCHMARK(T, F, p) \ - { \ - const device_partition_flag_benchmark instance; \ - REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ - } - -#define CREATE_PARTITION_PREDICATE_BENCHMARK(T, p) \ - { \ - const device_partition_predicate_benchmark instance; \ - REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ - } - -#define CREATE_PARTITION_TWO_WAY_FLAG_BENCHMARK(T, F, p) \ - { \ - const device_partition_two_way_flag_benchmark instance; \ - REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ - } - -#define CREATE_PARTITION_TWO_WAY_PREDICATE_BENCHMARK(T, p) \ - { \ - const device_partition_two_way_predicate_benchmark \ - instance; \ - REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ - } - -#define CREATE_PARTITION_THREE_WAY_BENCHMARK(T, p) \ - { \ - const device_partition_three_way_benchmark instance; \ - REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ - } - -#define BENCHMARK_FLAG_TYPE(type, flag_type) \ - CREATE_PARTITION_FLAG_BENCHMARK(type, flag_type, partition_probability::p005); \ - CREATE_PARTITION_FLAG_BENCHMARK(type, flag_type, partition_probability::p025); \ - CREATE_PARTITION_FLAG_BENCHMARK(type, flag_type, partition_probability::p050); \ +#define CREATE_PARTITION_THREE_WAY_BENCHMARK(T, p) \ + executor.queue_instance(device_partition_three_way_benchmark()); + +#define BENCHMARK_FLAG_TYPE(type, flag_type) \ + CREATE_PARTITION_FLAG_BENCHMARK(type, flag_type, partition_probability::p005) \ + CREATE_PARTITION_FLAG_BENCHMARK(type, flag_type, partition_probability::p025) \ + CREATE_PARTITION_FLAG_BENCHMARK(type, flag_type, partition_probability::p050) \ CREATE_PARTITION_FLAG_BENCHMARK(type, flag_type, partition_probability::p075) -#define BENCHMARK_PREDICATE_TYPE(type) \ - CREATE_PARTITION_PREDICATE_BENCHMARK(type, partition_probability::p005); \ - CREATE_PARTITION_PREDICATE_BENCHMARK(type, partition_probability::p025); \ - CREATE_PARTITION_PREDICATE_BENCHMARK(type, partition_probability::p050); \ +#define BENCHMARK_PREDICATE_TYPE(type) \ + CREATE_PARTITION_PREDICATE_BENCHMARK(type, partition_probability::p005) \ + CREATE_PARTITION_PREDICATE_BENCHMARK(type, partition_probability::p025) \ + CREATE_PARTITION_PREDICATE_BENCHMARK(type, partition_probability::p050) \ CREATE_PARTITION_PREDICATE_BENCHMARK(type, partition_probability::p075) -#define BENCHMARK_TWO_WAY_FLAG_TYPE(type, flag_type) \ - CREATE_PARTITION_TWO_WAY_FLAG_BENCHMARK(type, flag_type, partition_probability::p005); \ - CREATE_PARTITION_TWO_WAY_FLAG_BENCHMARK(type, flag_type, partition_probability::p025); \ - CREATE_PARTITION_TWO_WAY_FLAG_BENCHMARK(type, flag_type, partition_probability::p050); \ +#define BENCHMARK_TWO_WAY_FLAG_TYPE(type, flag_type) \ + CREATE_PARTITION_TWO_WAY_FLAG_BENCHMARK(type, flag_type, partition_probability::p005) \ + CREATE_PARTITION_TWO_WAY_FLAG_BENCHMARK(type, flag_type, partition_probability::p025) \ + CREATE_PARTITION_TWO_WAY_FLAG_BENCHMARK(type, flag_type, partition_probability::p050) \ CREATE_PARTITION_TWO_WAY_FLAG_BENCHMARK(type, flag_type, partition_probability::p075) -#define BENCHMARK_TWO_WAY_PREDICATE_TYPE(type) \ - CREATE_PARTITION_TWO_WAY_PREDICATE_BENCHMARK(type, partition_probability::p005); \ - CREATE_PARTITION_TWO_WAY_PREDICATE_BENCHMARK(type, partition_probability::p025); \ - CREATE_PARTITION_TWO_WAY_PREDICATE_BENCHMARK(type, partition_probability::p050); \ +#define BENCHMARK_TWO_WAY_PREDICATE_TYPE(type) \ + CREATE_PARTITION_TWO_WAY_PREDICATE_BENCHMARK(type, partition_probability::p005) \ + CREATE_PARTITION_TWO_WAY_PREDICATE_BENCHMARK(type, partition_probability::p025) \ + CREATE_PARTITION_TWO_WAY_PREDICATE_BENCHMARK(type, partition_probability::p050) \ CREATE_PARTITION_TWO_WAY_PREDICATE_BENCHMARK(type, partition_probability::p075) -#define BENCHMARK_THREE_WAY_TYPE(type) \ - CREATE_PARTITION_THREE_WAY_BENCHMARK(type, partition_three_way_probability::p005_p025); \ - CREATE_PARTITION_THREE_WAY_BENCHMARK(type, partition_three_way_probability::p025_p050); \ - CREATE_PARTITION_THREE_WAY_BENCHMARK(type, partition_three_way_probability::p050_p075); \ +#define BENCHMARK_THREE_WAY_TYPE(type) \ + CREATE_PARTITION_THREE_WAY_BENCHMARK(type, partition_three_way_probability::p005_p025) \ + CREATE_PARTITION_THREE_WAY_BENCHMARK(type, partition_three_way_probability::p025_p050) \ + CREATE_PARTITION_THREE_WAY_BENCHMARK(type, partition_three_way_probability::p050_p075) \ CREATE_PARTITION_THREE_WAY_BENCHMARK(type, partition_three_way_probability::p075_p100) int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); -#ifdef BENCHMARK_CONFIG_TUNING - // optionally run an evenly split subset of benchmarks, when making multiple program invocations - parser.set_optional("parallel_instance", - "parallel_instance", - 0, - "parallel instance index"); - parser.set_optional("parallel_instances", - "parallel_instances", - 1, - "total parallel instances"); -#endif - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); - - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks = {}; -#ifdef BENCHMARK_CONFIG_TUNING - const int parallel_instance = parser.get("parallel_instance"); - const int parallel_instances = parser.get("parallel_instances"); - config_autotune_register::register_benchmark_subset(benchmarks, - parallel_instance, - parallel_instances, - bytes, - seed, - stream); -#else + benchmark_utils::executor executor(argc, argv, 128 * benchmark_utils::MiB, 10, 5); + +#ifndef BENCHMARK_CONFIG_TUNING using custom_double2 = common::custom_type; using custom_int_double = common::custom_type; using huge_float2 = common::custom_huge_type<1024, float, float>; - BENCHMARK_FLAG_TYPE(int, unsigned char); - BENCHMARK_FLAG_TYPE(float, unsigned char); - BENCHMARK_FLAG_TYPE(double, unsigned char); - BENCHMARK_FLAG_TYPE(uint8_t, uint8_t); - BENCHMARK_FLAG_TYPE(int8_t, int8_t); - BENCHMARK_FLAG_TYPE(rocprim::half, int8_t); - BENCHMARK_FLAG_TYPE(custom_double2, unsigned char); - BENCHMARK_FLAG_TYPE(rocprim::int128_t, int8_t); - BENCHMARK_FLAG_TYPE(rocprim::uint128_t, uint8_t); - BENCHMARK_FLAG_TYPE(huge_float2, uint8_t); - - BENCHMARK_PREDICATE_TYPE(int); - BENCHMARK_PREDICATE_TYPE(float); - BENCHMARK_PREDICATE_TYPE(double); - BENCHMARK_PREDICATE_TYPE(uint8_t); - BENCHMARK_PREDICATE_TYPE(int8_t); - BENCHMARK_PREDICATE_TYPE(rocprim::half); - BENCHMARK_PREDICATE_TYPE(custom_int_double); - BENCHMARK_PREDICATE_TYPE(rocprim::int128_t); - BENCHMARK_PREDICATE_TYPE(rocprim::uint128_t); - BENCHMARK_PREDICATE_TYPE(huge_float2); - - BENCHMARK_TWO_WAY_FLAG_TYPE(int, unsigned char); - BENCHMARK_TWO_WAY_FLAG_TYPE(float, unsigned char); - BENCHMARK_TWO_WAY_FLAG_TYPE(double, unsigned char); - BENCHMARK_TWO_WAY_FLAG_TYPE(uint8_t, uint8_t); - BENCHMARK_TWO_WAY_FLAG_TYPE(int8_t, int8_t); - BENCHMARK_TWO_WAY_FLAG_TYPE(rocprim::half, int8_t); - BENCHMARK_TWO_WAY_FLAG_TYPE(custom_double2, unsigned char); - BENCHMARK_TWO_WAY_FLAG_TYPE(rocprim::int128_t, int8_t); - BENCHMARK_TWO_WAY_FLAG_TYPE(rocprim::uint128_t, uint8_t); - BENCHMARK_TWO_WAY_FLAG_TYPE(huge_float2, uint8_t); - - BENCHMARK_TWO_WAY_PREDICATE_TYPE(int); - BENCHMARK_TWO_WAY_PREDICATE_TYPE(float); - BENCHMARK_TWO_WAY_PREDICATE_TYPE(double); - BENCHMARK_TWO_WAY_PREDICATE_TYPE(uint8_t); - BENCHMARK_TWO_WAY_PREDICATE_TYPE(int8_t); - BENCHMARK_TWO_WAY_PREDICATE_TYPE(rocprim::half); - BENCHMARK_TWO_WAY_PREDICATE_TYPE(custom_int_double); - BENCHMARK_TWO_WAY_PREDICATE_TYPE(rocprim::int128_t); - BENCHMARK_TWO_WAY_PREDICATE_TYPE(rocprim::uint128_t); - BENCHMARK_TWO_WAY_PREDICATE_TYPE(huge_float2); - - BENCHMARK_THREE_WAY_TYPE(int); - BENCHMARK_THREE_WAY_TYPE(float); - BENCHMARK_THREE_WAY_TYPE(double); - BENCHMARK_THREE_WAY_TYPE(uint8_t); - BENCHMARK_THREE_WAY_TYPE(int8_t); - BENCHMARK_THREE_WAY_TYPE(rocprim::half); - BENCHMARK_THREE_WAY_TYPE(custom_int_double); - BENCHMARK_THREE_WAY_TYPE(rocprim::int128_t); - BENCHMARK_THREE_WAY_TYPE(rocprim::uint128_t); - BENCHMARK_THREE_WAY_TYPE(huge_float2); + BENCHMARK_FLAG_TYPE(int, unsigned char) + BENCHMARK_FLAG_TYPE(float, unsigned char) + BENCHMARK_FLAG_TYPE(double, unsigned char) + BENCHMARK_FLAG_TYPE(uint8_t, uint8_t) + BENCHMARK_FLAG_TYPE(int8_t, int8_t) + BENCHMARK_FLAG_TYPE(rocprim::half, int8_t) + BENCHMARK_FLAG_TYPE(custom_double2, unsigned char) + BENCHMARK_FLAG_TYPE(rocprim::int128_t, int8_t) + BENCHMARK_FLAG_TYPE(rocprim::uint128_t, uint8_t) + BENCHMARK_FLAG_TYPE(huge_float2, uint8_t) + + BENCHMARK_PREDICATE_TYPE(int) + BENCHMARK_PREDICATE_TYPE(float) + BENCHMARK_PREDICATE_TYPE(double) + BENCHMARK_PREDICATE_TYPE(uint8_t) + BENCHMARK_PREDICATE_TYPE(int8_t) + BENCHMARK_PREDICATE_TYPE(rocprim::half) + BENCHMARK_PREDICATE_TYPE(custom_int_double) + BENCHMARK_PREDICATE_TYPE(rocprim::int128_t) + BENCHMARK_PREDICATE_TYPE(rocprim::uint128_t) + BENCHMARK_PREDICATE_TYPE(huge_float2) + + BENCHMARK_TWO_WAY_FLAG_TYPE(int, unsigned char) + BENCHMARK_TWO_WAY_FLAG_TYPE(float, unsigned char) + BENCHMARK_TWO_WAY_FLAG_TYPE(double, unsigned char) + BENCHMARK_TWO_WAY_FLAG_TYPE(uint8_t, uint8_t) + BENCHMARK_TWO_WAY_FLAG_TYPE(int8_t, int8_t) + BENCHMARK_TWO_WAY_FLAG_TYPE(rocprim::half, int8_t) + BENCHMARK_TWO_WAY_FLAG_TYPE(custom_double2, unsigned char) + BENCHMARK_TWO_WAY_FLAG_TYPE(rocprim::int128_t, int8_t) + BENCHMARK_TWO_WAY_FLAG_TYPE(rocprim::uint128_t, uint8_t) + BENCHMARK_TWO_WAY_FLAG_TYPE(huge_float2, uint8_t) + + BENCHMARK_TWO_WAY_PREDICATE_TYPE(int) + BENCHMARK_TWO_WAY_PREDICATE_TYPE(float) + BENCHMARK_TWO_WAY_PREDICATE_TYPE(double) + BENCHMARK_TWO_WAY_PREDICATE_TYPE(uint8_t) + BENCHMARK_TWO_WAY_PREDICATE_TYPE(int8_t) + BENCHMARK_TWO_WAY_PREDICATE_TYPE(rocprim::half) + BENCHMARK_TWO_WAY_PREDICATE_TYPE(custom_int_double) + BENCHMARK_TWO_WAY_PREDICATE_TYPE(rocprim::int128_t) + BENCHMARK_TWO_WAY_PREDICATE_TYPE(rocprim::uint128_t) + BENCHMARK_TWO_WAY_PREDICATE_TYPE(huge_float2) + + BENCHMARK_THREE_WAY_TYPE(int) + BENCHMARK_THREE_WAY_TYPE(float) + BENCHMARK_THREE_WAY_TYPE(double) + BENCHMARK_THREE_WAY_TYPE(uint8_t) + BENCHMARK_THREE_WAY_TYPE(int8_t) + BENCHMARK_THREE_WAY_TYPE(rocprim::half) + BENCHMARK_THREE_WAY_TYPE(custom_int_double) + BENCHMARK_THREE_WAY_TYPE(rocprim::int128_t) + BENCHMARK_THREE_WAY_TYPE(rocprim::uint128_t) + BENCHMARK_THREE_WAY_TYPE(huge_float2) #endif - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } - - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_partition.parallel.cpp.in b/benchmark/benchmark_device_partition.parallel.cpp.in index d06cdfabc..e4191fc15 100644 --- a/benchmark/benchmark_device_partition.parallel.cpp.in +++ b/benchmark/benchmark_device_partition.parallel.cpp.in @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -29,6 +29,6 @@ namespace { -auto benchmarks = config_autotune_register::create_bulk( +auto unused = benchmark_utils::executor::queue_autotune( device_partition_benchmark_generator<@DataType@, @BlockSize@>::create); } // namespace diff --git a/benchmark/benchmark_device_partition.parallel.hpp b/benchmark/benchmark_device_partition.parallel.hpp index 717bebdb2..e0d1dfd7d 100644 --- a/benchmark/benchmark_device_partition.parallel.hpp +++ b/benchmark/benchmark_device_partition.parallel.hpp @@ -24,9 +24,11 @@ #define ROCPRIM_BENCHMARK_DEVICE_PARTITION_PARALLEL_HPP_ #include "benchmark_utils.hpp" -#include "cmdparser.hpp" #include "../common/utils_data_generation.hpp" +#include "../common/utils_device_ptr.hpp" + +#include "cmdparser.hpp" #include @@ -118,14 +120,11 @@ inline const char* get_probability_name(partition_three_way_probability probabil return "invalid"; } -constexpr int warmup_iter = 5; -constexpr int batch_size = 10; - template -struct device_partition_flag_benchmark : public config_autotune_interface +struct device_partition_flag_benchmark : public benchmark_utils::autotune_interface { std::string name() const override { @@ -137,11 +136,12 @@ struct device_partition_flag_benchmark : public config_autotune_interface + ",cfg:" + partition_config_name() + "}"); } - void run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - const hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + // Calculate the number of elements size_t size = bytes / sizeof(DataType); @@ -166,35 +166,20 @@ struct device_partition_flag_benchmark : public config_autotune_interface flags_0 = get_random_data01(size, get_probability(Probability), seed.get_1()); } - DataType* d_input{}; - HIP_CHECK(hipMalloc(&d_input, size * sizeof(*d_input))); - HIP_CHECK(hipMemcpy(d_input, input.data(), size * sizeof(*d_input), hipMemcpyHostToDevice)); + common::device_ptr d_input(input); - FlagType* d_flags_0{}; - FlagType* d_flags_1{}; - FlagType* d_flags_2{}; - HIP_CHECK(hipMalloc(&d_flags_0, size * sizeof(*d_flags_0))); - HIP_CHECK( - hipMemcpy(d_flags_0, flags_0.data(), size * sizeof(*d_flags_0), hipMemcpyHostToDevice)); + common::device_ptr d_flags_0(flags_0); + common::device_ptr d_flags_1; + common::device_ptr d_flags_2; if(is_tuning) { - HIP_CHECK(hipMalloc(&d_flags_1, size * sizeof(*d_flags_1))); - HIP_CHECK(hipMemcpy(d_flags_1, - flags_1.data(), - size * sizeof(*d_flags_1), - hipMemcpyHostToDevice)); - HIP_CHECK(hipMalloc(&d_flags_2, size * sizeof(*d_flags_2))); - HIP_CHECK(hipMemcpy(d_flags_2, - flags_2.data(), - size * sizeof(*d_flags_2), - hipMemcpyHostToDevice)); + d_flags_1.store(flags_1); + d_flags_2.store(flags_2); } - DataType* d_output{}; - HIP_CHECK(hipMalloc(&d_output, size * sizeof(*d_output))); + common::device_ptr d_output(size); - unsigned int* d_selected_count_output{}; - HIP_CHECK(hipMalloc(&d_selected_count_output, sizeof(*d_selected_count_output))); + common::device_ptr d_selected_count_output(1); const auto dispatch = [&](void* d_temp_storage, size_t& temp_storage_size_bytes) { @@ -202,69 +187,30 @@ struct device_partition_flag_benchmark : public config_autotune_interface { HIP_CHECK(rocprim::partition(d_temp_storage, temp_storage_size_bytes, - d_input, + d_input.get(), d_flags, - d_output, - d_selected_count_output, + d_output.get(), + d_selected_count_output.get(), size, stream)); }; - dispatch_flags(d_flags_0); + dispatch_flags(d_flags_0.get()); if(is_tuning) { - dispatch_flags(d_flags_1); - dispatch_flags(d_flags_2); + dispatch_flags(d_flags_1.get()); + dispatch_flags(d_flags_2.get()); } }; // Allocate temporary storage memory size_t temp_storage_size_bytes{}; dispatch(nullptr, temp_storage_size_bytes); - void* d_temp_storage{}; - HIP_CHECK(hipMalloc(&d_temp_storage, temp_storage_size_bytes)); + common::device_ptr d_temp_storage(temp_storage_size_bytes); - for(int i = 0; i < warmup_iter; ++i) - { - dispatch(d_temp_storage, temp_storage_size_bytes); - } - HIP_CHECK(hipDeviceSynchronize()); + state.run([&] { dispatch(d_temp_storage.get(), temp_storage_size_bytes); }); - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - HIP_CHECK(hipEventRecord(start, stream)); - for(int i = 0; i < batch_size; ++i) - { - dispatch(d_temp_storage, temp_storage_size_bytes); - } - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds{}; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(DataType)); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_input)); - if(is_tuning) - { - HIP_CHECK(hipFree(d_flags_2)); - HIP_CHECK(hipFree(d_flags_1)); - } - HIP_CHECK(hipFree(d_flags_0)); - HIP_CHECK(hipFree(d_output)); - HIP_CHECK(hipFree(d_selected_count_output)); - HIP_CHECK(hipFree(d_temp_storage)); + state.set_throughput(size, sizeof(DataType)); } static constexpr bool is_tuning = Probability == partition_probability::tuning; @@ -273,7 +219,7 @@ struct device_partition_flag_benchmark : public config_autotune_interface template -struct device_partition_predicate_benchmark : public config_autotune_interface +struct device_partition_predicate_benchmark : public benchmark_utils::autotune_interface { std::string name() const override { @@ -284,11 +230,12 @@ struct device_partition_predicate_benchmark : public config_autotune_interface + ",cfg:" + partition_config_name() + "}"); } - void run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - const hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + // Calculate the number of elements size_t size = bytes / sizeof(DataType); @@ -298,15 +245,11 @@ struct device_partition_predicate_benchmark : public config_autotune_interface static_cast(126), seed.get_0()); - DataType* d_input{}; - HIP_CHECK(hipMalloc(&d_input, size * sizeof(*d_input))); - HIP_CHECK(hipMemcpy(d_input, input.data(), size * sizeof(*d_input), hipMemcpyHostToDevice)); + common::device_ptr d_input(input); - DataType* d_output{}; - HIP_CHECK(hipMalloc(&d_output, size * sizeof(*d_output))); + common::device_ptr d_output(size); - unsigned int* d_selected_count_output{}; - HIP_CHECK(hipMalloc(&d_selected_count_output, sizeof(*d_selected_count_output))); + common::device_ptr d_selected_count_output(1); const auto dispatch = [&](void* d_temp_storage, size_t& temp_storage_size_bytes) { @@ -316,9 +259,9 @@ struct device_partition_predicate_benchmark : public config_autotune_interface { return value < static_cast(127 * probability); }; HIP_CHECK(rocprim::partition(d_temp_storage, temp_storage_size_bytes, - d_input, - d_output, - d_selected_count_output, + d_input.get(), + d_output.get(), + d_selected_count_output.get(), size, predicate, stream)); @@ -339,45 +282,12 @@ struct device_partition_predicate_benchmark : public config_autotune_interface // Allocate temporary storage memory size_t temp_storage_size_bytes{}; dispatch(nullptr, temp_storage_size_bytes); - void* d_temp_storage{}; - HIP_CHECK(hipMalloc(&d_temp_storage, temp_storage_size_bytes)); - HIP_CHECK(hipDeviceSynchronize()); - - for(int i = 0; i < warmup_iter; ++i) - { - dispatch(d_temp_storage, temp_storage_size_bytes); - } + common::device_ptr d_temp_storage(temp_storage_size_bytes); HIP_CHECK(hipDeviceSynchronize()); - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - HIP_CHECK(hipEventRecord(start, stream)); - for(int i = 0; i < batch_size; ++i) - { - dispatch(d_temp_storage, temp_storage_size_bytes); - } - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); + state.run([&] { dispatch(d_temp_storage.get(), temp_storage_size_bytes); }); - float elapsed_mseconds{}; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(DataType)); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_input)); - HIP_CHECK(hipFree(d_output)); - HIP_CHECK(hipFree(d_selected_count_output)); - HIP_CHECK(hipFree(d_temp_storage)); + state.set_throughput(size, sizeof(DataType)); } static constexpr bool is_tuning = Probability == partition_probability::tuning; @@ -387,7 +297,7 @@ template -struct device_partition_two_way_flag_benchmark : public config_autotune_interface +struct device_partition_two_way_flag_benchmark : public benchmark_utils::autotune_interface { std::string name() const override { @@ -399,11 +309,12 @@ struct device_partition_two_way_flag_benchmark : public config_autotune_interfac + get_probability_name(Probability) + ",cfg:" + partition_config_name() + "}"); } - void run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - const hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + // Calculate the number of elements size_t size = bytes / sizeof(DataType); @@ -428,38 +339,22 @@ struct device_partition_two_way_flag_benchmark : public config_autotune_interfac flags_0 = get_random_data01(size, get_probability(Probability), seed.get_1()); } - DataType* d_input{}; - HIP_CHECK(hipMalloc(&d_input, size * sizeof(*d_input))); - HIP_CHECK(hipMemcpy(d_input, input.data(), size * sizeof(*d_input), hipMemcpyHostToDevice)); + common::device_ptr d_input(input); - FlagType* d_flags_0{}; - FlagType* d_flags_1{}; - FlagType* d_flags_2{}; - HIP_CHECK(hipMalloc(&d_flags_0, size * sizeof(*d_flags_0))); - HIP_CHECK( - hipMemcpy(d_flags_0, flags_0.data(), size * sizeof(*d_flags_0), hipMemcpyHostToDevice)); + common::device_ptr d_flags_0(flags_0); + common::device_ptr d_flags_1; + common::device_ptr d_flags_2; if(is_tuning) { - HIP_CHECK(hipMalloc(&d_flags_1, size * sizeof(*d_flags_1))); - HIP_CHECK(hipMemcpy(d_flags_1, - flags_1.data(), - size * sizeof(*d_flags_1), - hipMemcpyHostToDevice)); - HIP_CHECK(hipMalloc(&d_flags_2, size * sizeof(*d_flags_2))); - HIP_CHECK(hipMemcpy(d_flags_2, - flags_2.data(), - size * sizeof(*d_flags_2), - hipMemcpyHostToDevice)); + d_flags_1.store(flags_1); + d_flags_2.store(flags_2); } - DataType* d_output_selected{}; - HIP_CHECK(hipMalloc(&d_output_selected, size * sizeof(*d_output_selected))); + common::device_ptr d_output_selected(size); - DataType* d_output_rejected{}; - HIP_CHECK(hipMalloc(&d_output_rejected, size * sizeof(*d_output_rejected))); + common::device_ptr d_output_rejected(size); - unsigned int* d_selected_count_output{}; - HIP_CHECK(hipMalloc(&d_selected_count_output, sizeof(*d_selected_count_output))); + common::device_ptr d_selected_count_output(1); const auto dispatch = [&](void* d_temp_storage, size_t& temp_storage_size_bytes) { @@ -467,71 +362,31 @@ struct device_partition_two_way_flag_benchmark : public config_autotune_interfac { HIP_CHECK(rocprim::partition_two_way(d_temp_storage, temp_storage_size_bytes, - d_input, + d_input.get(), d_flags, - d_output_selected, - d_output_rejected, - d_selected_count_output, + d_output_selected.get(), + d_output_rejected.get(), + d_selected_count_output.get(), size, stream)); }; - dispatch_flags(d_flags_0); + dispatch_flags(d_flags_0.get()); if(is_tuning) { - dispatch_flags(d_flags_1); - dispatch_flags(d_flags_2); + dispatch_flags(d_flags_1.get()); + dispatch_flags(d_flags_2.get()); } }; // Allocate temporary storage memory size_t temp_storage_size_bytes = 0; dispatch(nullptr, temp_storage_size_bytes); - void* d_temp_storage = nullptr; - HIP_CHECK(hipMalloc(&d_temp_storage, temp_storage_size_bytes)); - - for(int i = 0; i < warmup_iter; ++i) - { - dispatch(d_temp_storage, temp_storage_size_bytes); - } - HIP_CHECK(hipDeviceSynchronize()); + common::device_ptr d_temp_storage(temp_storage_size_bytes); - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); + state.run([&] { dispatch(d_temp_storage.get(), temp_storage_size_bytes); }); - for(auto _ : state) - { - HIP_CHECK(hipEventRecord(start, stream)); - for(int i = 0; i < batch_size; ++i) - { - dispatch(d_temp_storage, temp_storage_size_bytes); - } - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds{}; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(DataType)); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_input)); - if(is_tuning) - { - HIP_CHECK(hipFree(d_flags_2)); - HIP_CHECK(hipFree(d_flags_1)); - } - HIP_CHECK(hipFree(d_flags_0)); - HIP_CHECK(hipFree(d_output_selected)); - HIP_CHECK(hipFree(d_output_rejected)); - HIP_CHECK(hipFree(d_selected_count_output)); - HIP_CHECK(hipFree(d_temp_storage)); + state.set_throughput(size, sizeof(DataType)); } static constexpr bool is_tuning = Probability == partition_probability::tuning; @@ -540,7 +395,7 @@ struct device_partition_two_way_flag_benchmark : public config_autotune_interfac template -struct device_partition_two_way_predicate_benchmark : public config_autotune_interface +struct device_partition_two_way_predicate_benchmark : public benchmark_utils::autotune_interface { std::string name() const override { @@ -551,11 +406,12 @@ struct device_partition_two_way_predicate_benchmark : public config_autotune_int + get_probability_name(Probability) + ",cfg:" + partition_config_name() + "}"); } - void run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - const hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + // Calculate the number of elements size_t size = bytes / sizeof(DataType); @@ -565,18 +421,13 @@ struct device_partition_two_way_predicate_benchmark : public config_autotune_int static_cast(126), seed.get_0()); - DataType* d_input; - HIP_CHECK(hipMalloc(&d_input, size * sizeof(*d_input))); - HIP_CHECK(hipMemcpy(d_input, input.data(), size * sizeof(*d_input), hipMemcpyHostToDevice)); + common::device_ptr d_input(input); - DataType* d_output_selected; - HIP_CHECK(hipMalloc(&d_output_selected, size * sizeof(*d_output_selected))); + common::device_ptr d_output_selected(size); - DataType* d_output_rejected; - HIP_CHECK(hipMalloc(&d_output_rejected, size * sizeof(*d_output_selected))); + common::device_ptr d_output_rejected(size); - unsigned int* d_selected_count_output; - HIP_CHECK(hipMalloc(&d_selected_count_output, sizeof(*d_selected_count_output))); + common::device_ptr d_selected_count_output(1); const auto dispatch = [&](void* d_temp_storage, size_t& temp_storage_size_bytes) { @@ -586,10 +437,10 @@ struct device_partition_two_way_predicate_benchmark : public config_autotune_int { return value < static_cast(127 * probability); }; HIP_CHECK(rocprim::partition_two_way(d_temp_storage, temp_storage_size_bytes, - d_input, - d_output_selected, - d_output_rejected, - d_selected_count_output, + d_input.get(), + d_output_selected.get(), + d_output_rejected.get(), + d_selected_count_output.get(), size, predicate, stream)); @@ -610,45 +461,11 @@ struct device_partition_two_way_predicate_benchmark : public config_autotune_int // Allocate temporary storage memory size_t temp_storage_size_bytes{}; dispatch(nullptr, temp_storage_size_bytes); - void* d_temp_storage{}; - HIP_CHECK(hipMalloc(&d_temp_storage, temp_storage_size_bytes)); + common::device_ptr d_temp_storage(temp_storage_size_bytes); - for(int i = 0; i < warmup_iter; ++i) - { - dispatch(d_temp_storage, temp_storage_size_bytes); - } - HIP_CHECK(hipDeviceSynchronize()); + state.run([&] { dispatch(d_temp_storage.get(), temp_storage_size_bytes); }); - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - HIP_CHECK(hipEventRecord(start, stream)); - for(int i = 0; i < batch_size; ++i) - { - dispatch(d_temp_storage, temp_storage_size_bytes); - } - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds{}; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(DataType)); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_input)); - HIP_CHECK(hipFree(d_output_selected)); - HIP_CHECK(hipFree(d_output_rejected)); - HIP_CHECK(hipFree(d_selected_count_output)); - HIP_CHECK(hipFree(d_temp_storage)); + state.set_throughput(size, sizeof(DataType)); } static constexpr bool is_tuning = Probability == partition_probability::tuning; @@ -657,7 +474,7 @@ struct device_partition_two_way_predicate_benchmark : public config_autotune_int template -struct device_partition_three_way_benchmark : public config_autotune_interface +struct device_partition_three_way_benchmark : public benchmark_utils::autotune_interface { std::string name() const override { @@ -668,11 +485,12 @@ struct device_partition_three_way_benchmark : public config_autotune_interface + ",cfg:" + partition_config_name() + "}"); } - void run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - const hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + // Calculate the number of elements size_t size = bytes / sizeof(DataType); @@ -682,21 +500,15 @@ struct device_partition_three_way_benchmark : public config_autotune_interface static_cast(126), seed.get_0()); - DataType* d_input{}; - HIP_CHECK(hipMalloc(&d_input, size * sizeof(*d_input))); - HIP_CHECK(hipMemcpy(d_input, input.data(), size * sizeof(*d_input), hipMemcpyHostToDevice)); + common::device_ptr d_input(input); - DataType* d_output_first{}; - HIP_CHECK(hipMalloc(&d_output_first, size * sizeof(*d_output_first))); + common::device_ptr d_output_first(size); - DataType* d_output_second{}; - HIP_CHECK(hipMalloc(&d_output_second, size * sizeof(*d_output_second))); + common::device_ptr d_output_second(size); - DataType* d_output_unselected{}; - HIP_CHECK(hipMalloc(&d_output_unselected, size * sizeof(*d_output_unselected))); + common::device_ptr d_output_unselected(size); - unsigned int* d_selected_count_output{}; - HIP_CHECK(hipMalloc(&d_selected_count_output, 2 * sizeof(*d_selected_count_output))); + common::device_ptr d_selected_count_output(2); const auto dispatch = [&](void* d_temp_storage, size_t& temp_storage_size_bytes) { @@ -711,11 +523,11 @@ struct device_partition_three_way_benchmark : public config_autotune_interface HIP_CHECK(rocprim::partition_three_way(d_temp_storage, temp_storage_size_bytes, - d_input, - d_output_first, - d_output_second, - d_output_unselected, - d_selected_count_output, + d_input.get(), + d_output_first.get(), + d_output_second.get(), + d_output_unselected.get(), + d_selected_count_output.get(), size, predicate_one, predicate_two, @@ -749,46 +561,11 @@ struct device_partition_three_way_benchmark : public config_autotune_interface // Allocate temporary storage memory size_t temp_storage_size_bytes{}; dispatch(nullptr, temp_storage_size_bytes); - void* d_temp_storage{}; - HIP_CHECK(hipMalloc(&d_temp_storage, temp_storage_size_bytes)); - - for(int i = 0; i < warmup_iter; ++i) - { - dispatch(d_temp_storage, temp_storage_size_bytes); - } - HIP_CHECK(hipDeviceSynchronize()); - - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - HIP_CHECK(hipEventRecord(start, stream)); - for(int i = 0; i < batch_size; ++i) - { - dispatch(d_temp_storage, temp_storage_size_bytes); - } - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds{}; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); + common::device_ptr d_temp_storage(temp_storage_size_bytes); - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(DataType)); - state.SetItemsProcessed(state.iterations() * batch_size * size); + state.run([&] { dispatch(d_temp_storage.get(), temp_storage_size_bytes); }); - HIP_CHECK(hipFree(d_input)); - HIP_CHECK(hipFree(d_output_first)); - HIP_CHECK(hipFree(d_output_second)); - HIP_CHECK(hipFree(d_output_unselected)); - HIP_CHECK(hipFree(d_selected_count_output)); - HIP_CHECK(hipFree(d_temp_storage)); + state.set_throughput(size, sizeof(DataType)); } static constexpr bool is_tuning = Probability == partition_three_way_probability::tuning; @@ -802,7 +579,7 @@ struct device_partition_benchmark_generator template struct create_ipt { - void operator()(std::vector>& storage) + void operator()(std::vector>& storage) { using config = rocprim::select_config; storage.emplace_back( @@ -818,7 +595,7 @@ struct device_partition_benchmark_generator } }; - static void create(std::vector>& storage) + static void create(std::vector>& storage) { static constexpr int max_items_per_thread = std::min(64 / sizeof(DataType), size_t{32}); static_for_each, create_ipt>(storage); diff --git a/benchmark/benchmark_device_radix_sort.cpp b/benchmark/benchmark_device_radix_sort.cpp index bef154ba6..2c54a5a14 100644 --- a/benchmark/benchmark_device_radix_sort.cpp +++ b/benchmark/benchmark_device_radix_sort.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -22,11 +22,6 @@ #include "benchmark_device_radix_sort.hpp" #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" - -// Google Benchmark -#include // HIP API #include @@ -35,60 +30,48 @@ #include #include -#ifndef DEFAULT_BYTES -const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; -#endif +#define CREATE_RADIX_SORT_BENCHMARK(...) \ + executor.queue_instance(device_radix_sort_benchmark<__VA_ARGS__>()); int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); - - // HIP - hipStream_t stream = 0; // default + benchmark_utils::executor executor(argc, argv, 128 * benchmark_utils::MiB, 10, 5); - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); + using custom_key = common::custom_type; + CREATE_RADIX_SORT_BENCHMARK(int) + CREATE_RADIX_SORT_BENCHMARK(float) + CREATE_RADIX_SORT_BENCHMARK(long long) + CREATE_RADIX_SORT_BENCHMARK(int8_t) + CREATE_RADIX_SORT_BENCHMARK(uint8_t) + CREATE_RADIX_SORT_BENCHMARK(rocprim::half) + CREATE_RADIX_SORT_BENCHMARK(short) + CREATE_RADIX_SORT_BENCHMARK(custom_key) + CREATE_RADIX_SORT_BENCHMARK(rocprim::int128_t) + CREATE_RADIX_SORT_BENCHMARK(rocprim::uint128_t) - // Add benchmarks - std::vector benchmarks = {}; - add_sort_keys_benchmarks(benchmarks, bytes, seed, stream); - add_sort_pairs_benchmarks(benchmarks, bytes, seed, stream); + using custom_float2 = common::custom_type; + using custom_double2 = common::custom_type; + using custom_key = common::custom_type; - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } + CREATE_RADIX_SORT_BENCHMARK(int, float) + CREATE_RADIX_SORT_BENCHMARK(int, double) + CREATE_RADIX_SORT_BENCHMARK(int, float2) + CREATE_RADIX_SORT_BENCHMARK(int, custom_float2) + CREATE_RADIX_SORT_BENCHMARK(int, double2) + CREATE_RADIX_SORT_BENCHMARK(int, custom_double2) - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } + CREATE_RADIX_SORT_BENCHMARK(long long, float) + CREATE_RADIX_SORT_BENCHMARK(long long, double) + CREATE_RADIX_SORT_BENCHMARK(long long, float2) + CREATE_RADIX_SORT_BENCHMARK(long long, custom_float2) + CREATE_RADIX_SORT_BENCHMARK(long long, double2) + CREATE_RADIX_SORT_BENCHMARK(long long, custom_double2) + CREATE_RADIX_SORT_BENCHMARK(int8_t, int8_t) + CREATE_RADIX_SORT_BENCHMARK(uint8_t, uint8_t) + CREATE_RADIX_SORT_BENCHMARK(rocprim::half, rocprim::half) + CREATE_RADIX_SORT_BENCHMARK(custom_key, double) + CREATE_RADIX_SORT_BENCHMARK(rocprim::int128_t, rocprim::int128_t) + CREATE_RADIX_SORT_BENCHMARK(rocprim::uint128_t, rocprim::uint128_t) - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_radix_sort.hpp b/benchmark/benchmark_device_radix_sort.hpp index 30d865d06..af4891010 100644 --- a/benchmark/benchmark_device_radix_sort.hpp +++ b/benchmark/benchmark_device_radix_sort.hpp @@ -45,12 +45,10 @@ #include #include -namespace rp = rocprim; - template -struct device_radix_sort_benchmark : public config_autotune_interface +struct device_radix_sort_benchmark : public benchmark_utils::autotune_interface { std::string name() const override { @@ -59,17 +57,15 @@ struct device_radix_sort_benchmark : public config_autotune_interface + ",value_type:" + std::string(Traits::name()) + ",cfg: default_config}"); } - static constexpr unsigned int batch_size = 10; - static constexpr unsigned int warmup_size = 5; - // keys benchmark template - auto do_run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const + auto do_run(benchmark_utils::state&& state) const -> std::enable_if_t::value, void> { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + using key_type = Key; // Calculate the number of elements @@ -105,31 +101,8 @@ struct device_radix_sort_benchmark : public config_autotune_interface HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); HIP_CHECK(hipDeviceSynchronize()); - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - HIP_CHECK(invoke_radix_sort(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - d_keys_output, - static_cast(nullptr), - static_cast(nullptr), - size, - stream)); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) + state.run( + [&] { HIP_CHECK(invoke_radix_sort(d_temporary_storage, temporary_storage_bytes, @@ -139,23 +112,9 @@ struct device_radix_sort_benchmark : public config_autotune_interface static_cast(nullptr), size, stream)); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } + }); - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(key_type)); - state.SetItemsProcessed(state.iterations() * batch_size * size); + state.set_throughput(size, sizeof(key_type)); HIP_CHECK(hipFree(d_temporary_storage)); HIP_CHECK(hipFree(d_keys_input)); @@ -164,12 +123,13 @@ struct device_radix_sort_benchmark : public config_autotune_interface // pairs benchmark template - auto do_run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const + auto do_run(benchmark_utils::state&& state) const -> std::enable_if_t::value, void> { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + using key_type = Key; using value_type = Value; @@ -220,31 +180,8 @@ struct device_radix_sort_benchmark : public config_autotune_interface HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); HIP_CHECK(hipDeviceSynchronize()); - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - HIP_CHECK(invoke_radix_sort(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - d_keys_output, - d_values_input, - d_values_output, - size, - stream)); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) + state.run( + [&] { HIP_CHECK(invoke_radix_sort(d_temporary_storage, temporary_storage_bytes, @@ -254,24 +191,9 @@ struct device_radix_sort_benchmark : public config_autotune_interface d_values_output, size, stream)); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } + }); - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size - * (sizeof(key_type) + sizeof(value_type))); - state.SetItemsProcessed(state.iterations() * batch_size * size); + state.set_throughput(size, sizeof(key_type) + sizeof(value_type)); HIP_CHECK(hipFree(d_temporary_storage)); HIP_CHECK(hipFree(d_keys_input)); @@ -280,12 +202,9 @@ struct device_radix_sort_benchmark : public config_autotune_interface HIP_CHECK(hipFree(d_values_output)); } - void run(benchmark::State& state, - size_t size, - const managed_seed& seed, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { - do_run(state, size, seed, stream); + do_run(std::forward(state)); } private: @@ -299,19 +218,19 @@ struct device_radix_sort_benchmark : public config_autotune_interface size_t size, hipStream_t stream) -> std::enable_if_t::value - && std::is_same::value, + && std::is_same::value, hipError_t> { (void)values_input; (void)values_output; - return rp::radix_sort_keys(d_temporary_storage, - temp_storage_bytes, - keys_input, - keys_output, - size, - 0, - sizeof(K) * 8, - stream); + return rocprim::radix_sort_keys(d_temporary_storage, + temp_storage_bytes, + keys_input, + keys_output, + size, + 0, + sizeof(K) * 8, + stream); } template @@ -324,18 +243,18 @@ struct device_radix_sort_benchmark : public config_autotune_interface size_t size, hipStream_t stream) -> std::enable_if_t::value - && std::is_same::value, + && std::is_same::value, hipError_t> { (void)values_input; (void)values_output; - return rp::radix_sort_keys(d_temporary_storage, - temp_storage_bytes, - keys_input, - keys_output, - size, - custom_type_decomposer{}, - stream); + return rocprim::radix_sort_keys(d_temporary_storage, + temp_storage_bytes, + keys_input, + keys_output, + size, + custom_type_decomposer{}, + stream); } template @@ -348,19 +267,19 @@ struct device_radix_sort_benchmark : public config_autotune_interface size_t size, hipStream_t stream) -> std::enable_if_t::value - && !std::is_same::value, + && !std::is_same::value, hipError_t> { - return rp::radix_sort_pairs(d_temporary_storage, - temp_storage_bytes, - keys_input, - keys_output, - values_input, - values_output, - size, - 0, - sizeof(K) * 8, - stream); + return rocprim::radix_sort_pairs(d_temporary_storage, + temp_storage_bytes, + keys_input, + keys_output, + values_input, + values_output, + size, + 0, + sizeof(K) * 8, + stream); } template @@ -373,73 +292,19 @@ struct device_radix_sort_benchmark : public config_autotune_interface size_t size, hipStream_t stream) -> std::enable_if_t::value - && !std::is_same::value, + && !std::is_same::value, hipError_t> { - return rp::radix_sort_pairs(d_temporary_storage, - temp_storage_bytes, - keys_input, - keys_output, - values_input, - values_output, - size, - custom_type_decomposer{}, - stream); + return rocprim::radix_sort_pairs(d_temporary_storage, + temp_storage_bytes, + keys_input, + keys_output, + values_input, + values_output, + size, + custom_type_decomposer{}, + stream); } }; -#define CREATE_RADIX_SORT_BENCHMARK(...) \ - { \ - const device_radix_sort_benchmark<__VA_ARGS__> instance; \ - REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ - } - -inline void add_sort_keys_benchmarks(std::vector& benchmarks, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) -{ - using custom_key = common::custom_type; - CREATE_RADIX_SORT_BENCHMARK(int) - CREATE_RADIX_SORT_BENCHMARK(float) - CREATE_RADIX_SORT_BENCHMARK(long long) - CREATE_RADIX_SORT_BENCHMARK(int8_t) - CREATE_RADIX_SORT_BENCHMARK(uint8_t) - CREATE_RADIX_SORT_BENCHMARK(rocprim::half) - CREATE_RADIX_SORT_BENCHMARK(short) - CREATE_RADIX_SORT_BENCHMARK(custom_key) - CREATE_RADIX_SORT_BENCHMARK(rocprim::int128_t) - CREATE_RADIX_SORT_BENCHMARK(rocprim::uint128_t) -} - -inline void add_sort_pairs_benchmarks(std::vector& benchmarks, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) -{ - using custom_float2 = common::custom_type; - using custom_double2 = common::custom_type; - using custom_key = common::custom_type; - - CREATE_RADIX_SORT_BENCHMARK(int, float) - CREATE_RADIX_SORT_BENCHMARK(int, double) - CREATE_RADIX_SORT_BENCHMARK(int, float2) - CREATE_RADIX_SORT_BENCHMARK(int, custom_float2) - CREATE_RADIX_SORT_BENCHMARK(int, double2) - CREATE_RADIX_SORT_BENCHMARK(int, custom_double2) - - CREATE_RADIX_SORT_BENCHMARK(long long, float) - CREATE_RADIX_SORT_BENCHMARK(long long, double) - CREATE_RADIX_SORT_BENCHMARK(long long, float2) - CREATE_RADIX_SORT_BENCHMARK(long long, custom_float2) - CREATE_RADIX_SORT_BENCHMARK(long long, double2) - CREATE_RADIX_SORT_BENCHMARK(long long, custom_double2) - CREATE_RADIX_SORT_BENCHMARK(int8_t, int8_t) - CREATE_RADIX_SORT_BENCHMARK(uint8_t, uint8_t) - CREATE_RADIX_SORT_BENCHMARK(rocprim::half, rocprim::half) - CREATE_RADIX_SORT_BENCHMARK(custom_key, double) - CREATE_RADIX_SORT_BENCHMARK(rocprim::int128_t, rocprim::int128_t) - CREATE_RADIX_SORT_BENCHMARK(rocprim::uint128_t, rocprim::uint128_t) -} - #endif // ROCPRIM_BENCHMARK_DEVICE_RADIX_SORT_PARALLEL_HPP_ diff --git a/benchmark/benchmark_device_radix_sort_block_sort.cpp b/benchmark/benchmark_device_radix_sort_block_sort.cpp index 15b04b60e..68b4d1d33 100644 --- a/benchmark/benchmark_device_radix_sort_block_sort.cpp +++ b/benchmark/benchmark_device_radix_sort_block_sort.cpp @@ -20,12 +20,6 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -// CmdParser -#include "cmdparser.hpp" - -// Google Benchmark -#include - // HIP API #include @@ -47,67 +41,14 @@ #include #endif -#ifndef DEFAULT_BYTES -const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; -#endif - -#define CREATE_BENCHMARK(...) \ - { \ - const device_radix_sort_block_sort_benchmark<__VA_ARGS__> instance; \ - REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ - } +#define CREATE_BENCHMARK(...) \ + executor.queue_instance(device_radix_sort_block_sort_benchmark<__VA_ARGS__>()); int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); -#ifdef BENCHMARK_CONFIG_TUNING - // optionally run an evenly split subset of benchmarks, when making multiple program invocations - parser.set_optional("parallel_instance", - "parallel_instance", - 0, - "parallel instance index"); - parser.set_optional("parallel_instances", - "parallel_instances", - 1, - "total parallel instances"); -#endif - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); + benchmark_utils::executor executor(argc, argv, 128 * benchmark_utils::MiB, 10, 5); - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks = {}; -#ifdef BENCHMARK_CONFIG_TUNING - const int parallel_instance = parser.get("parallel_instance"); - const int parallel_instances = parser.get("parallel_instances"); - config_autotune_register::register_benchmark_subset(benchmarks, - parallel_instance, - parallel_instances, - bytes, - seed, - stream); -#else // BENCHMARK_CONFIG_TUNING +#ifndef BENCHMARK_CONFIG_TUNING CREATE_BENCHMARK(int) CREATE_BENCHMARK(long long) CREATE_BENCHMARK(int8_t) @@ -132,25 +73,7 @@ int main(int argc, char* argv[]) CREATE_BENCHMARK(long long, custom_double2) CREATE_BENCHMARK(rocprim::int128_t, rocprim::int128_t) CREATE_BENCHMARK(rocprim::uint128_t, rocprim::uint128_t) -#endif // BENCHMARK_CONFIG_TUNING - - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } +#endif - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_radix_sort_block_sort.parallel.cpp.in b/benchmark/benchmark_device_radix_sort_block_sort.parallel.cpp.in index 2e7c45ccc..f9d4adbd0 100644 --- a/benchmark/benchmark_device_radix_sort_block_sort.parallel.cpp.in +++ b/benchmark/benchmark_device_radix_sort_block_sort.parallel.cpp.in @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -28,6 +28,6 @@ #include namespace { - auto benchmarks = config_autotune_register::create_bulk( + auto unused = benchmark_utils::executor::queue_autotune( device_radix_sort_block_sort_benchmark_generator<@BlockSize@, @KeyType@, @ValueType@>::create); } diff --git a/benchmark/benchmark_device_radix_sort_block_sort.parallel.hpp b/benchmark/benchmark_device_radix_sort_block_sort.parallel.hpp index f5b96a82f..3fa500f7e 100644 --- a/benchmark/benchmark_device_radix_sort_block_sort.parallel.hpp +++ b/benchmark/benchmark_device_radix_sort_block_sort.parallel.hpp @@ -26,6 +26,7 @@ #include "benchmark_utils.hpp" #include "../common/utils_data_generation.hpp" +#include "../common/utils_device_ptr.hpp" // Google Benchmark #include @@ -62,7 +63,7 @@ inline std::string config_name() template -struct device_radix_sort_block_sort_benchmark : public config_autotune_interface +struct device_radix_sort_block_sort_benchmark : public benchmark_utils::autotune_interface { std::string name() const override { @@ -72,17 +73,15 @@ struct device_radix_sort_block_sort_benchmark : public config_autotune_interface + ",cfg:" + config_name() + "}"); } - static constexpr unsigned int batch_size = 10; - static constexpr unsigned int warmup_size = 5; - // keys benchmark template - auto do_run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const -> + auto do_run(benchmark_utils::state&& state) const -> typename std::enable_if::value, void>::type { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + using key_type = Key; // Calculate the number of elements @@ -95,50 +94,18 @@ struct device_radix_sort_block_sort_benchmark : public config_autotune_interface common::generate_limits::max(), seed.get_0()); - key_type* d_keys_input; - key_type* d_keys_output; - HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys_input), size * sizeof(key_type))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys_output), size * sizeof(key_type))); - HIP_CHECK(hipMemcpy(d_keys_input, - keys_input.data(), - size * sizeof(key_type), - hipMemcpyHostToDevice)); + common::device_ptr d_keys_input(keys_input); + common::device_ptr d_keys_output(size); rocprim::empty_type* values_ptr = nullptr; unsigned int items_per_block; - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - HIP_CHECK((rocprim::detail::radix_sort_block_sort( - d_keys_input, - d_keys_output, - values_ptr, - values_ptr, - size, - items_per_block, - rocprim::identity_decomposer{}, - 0, - sizeof(key_type) * 8, - stream, - false))); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - for(size_t i = 0; i < batch_size; ++i) + state.run( + [&] { HIP_CHECK((rocprim::detail::radix_sort_block_sort( - d_keys_input, - d_keys_output, + d_keys_input.get(), + d_keys_output.get(), values_ptr, values_ptr, size, @@ -148,36 +115,20 @@ struct device_radix_sort_block_sort_benchmark : public config_autotune_interface sizeof(key_type) * 8, stream, false))); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } + }); - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(key_type)); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_keys_input)); - HIP_CHECK(hipFree(d_keys_output)); + state.set_throughput(size, sizeof(key_type)); } // pairs benchmark template - auto do_run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const -> + auto do_run(benchmark_utils::state&& state) const -> typename std::enable_if::value, void>::type { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + using key_type = Key; using value_type = Value; @@ -197,63 +148,24 @@ struct device_radix_sort_block_sort_benchmark : public config_autotune_interface values_input[i] = value_type(i); } - key_type* d_keys_input; - key_type* d_keys_output; - HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys_input), size * sizeof(key_type))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys_output), size * sizeof(key_type))); - HIP_CHECK(hipMemcpy(d_keys_input, - keys_input.data(), - size * sizeof(key_type), - hipMemcpyHostToDevice)); - - value_type* d_values_input; - value_type* d_values_output; - HIP_CHECK(hipMalloc(reinterpret_cast(&d_values_input), size * sizeof(value_type))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_values_output), size * sizeof(value_type))); - HIP_CHECK(hipMemcpy(d_values_input, - values_input.data(), - size * sizeof(value_type), - hipMemcpyHostToDevice)); + common::device_ptr d_keys_input(keys_input); + common::device_ptr d_keys_output(size); + + common::device_ptr d_values_input(values_input); + common::device_ptr d_values_output(size); unsigned int items_per_block; HIP_CHECK(hipDeviceSynchronize()); - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - HIP_CHECK((rocprim::detail::radix_sort_block_sort( - d_keys_input, - d_keys_output, - d_values_input, - d_values_output, - size, - items_per_block, - rocprim::identity_decomposer{}, - 0, - sizeof(key_type) * 8, - stream, - false))); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) + state.run( + [&] { HIP_CHECK((rocprim::detail::radix_sort_block_sort( - d_keys_input, - d_keys_output, - d_values_input, - d_values_output, + d_keys_input.get(), + d_keys_output.get(), + d_values_input.get(), + d_values_output.get(), size, items_per_block, rocprim::identity_decomposer{}, @@ -261,37 +173,14 @@ struct device_radix_sort_block_sort_benchmark : public config_autotune_interface sizeof(key_type) * 8, stream, false))); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size - * (sizeof(key_type) + sizeof(value_type))); - state.SetItemsProcessed(state.iterations() * batch_size * size); + }); - HIP_CHECK(hipFree(d_keys_input)); - HIP_CHECK(hipFree(d_keys_output)); - HIP_CHECK(hipFree(d_values_input)); - HIP_CHECK(hipFree(d_values_output)); + state.set_throughput(size, sizeof(key_type) + sizeof(value_type)); } - void run(benchmark::State& state, - size_t size, - const managed_seed& seed, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { - do_run(state, size, seed, stream); + do_run(std::forward(state)); } }; @@ -303,7 +192,7 @@ struct device_radix_sort_block_sort_benchmark_generator { using generated_config = rocprim::kernel_config; - void operator()(std::vector>& storage) + void operator()(std::vector>& storage) { storage.emplace_back( std::make_unique< @@ -311,7 +200,7 @@ struct device_radix_sort_block_sort_benchmark_generator } }; - static void create(std::vector>& storage) + static void create(std::vector>& storage) { // Sort_items_per_block must be equal or larger than radix_items_per_block, so make // the items_per_thread at least as large so the sort_items_per_block diff --git a/benchmark/benchmark_device_radix_sort_onesweep.cpp b/benchmark/benchmark_device_radix_sort_onesweep.cpp index 3041cd5ea..a10013196 100644 --- a/benchmark/benchmark_device_radix_sort_onesweep.cpp +++ b/benchmark/benchmark_device_radix_sort_onesweep.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,12 +20,6 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -// CmdParser -#include "cmdparser.hpp" - -// Google Benchmark -#include - // HIP API #include @@ -36,82 +30,46 @@ #include #include -#ifndef DEFAULT_BYTES -const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; -#endif +#define CREATE_RADIX_SORT_BENCHMARK(...) \ + executor.queue_instance(device_radix_sort_onesweep_benchmark<__VA_ARGS__>()); int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); -#ifdef BENCHMARK_CONFIG_TUNING - // optionally run an evenly split subset of benchmarks, when making multiple program invocations - parser.set_optional("parallel_instance", - "parallel_instance", - 0, - "parallel instance index"); - parser.set_optional("parallel_instances", - "parallel_instances", - 1, - "total parallel instances"); -#endif - parser.run_and_exit_if_error(); + benchmark_utils::executor executor(argc, argv, 128 * benchmark_utils::MiB, 10, 5); - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); +#ifndef BENCHMARK_CONFIG_TUNING + CREATE_RADIX_SORT_BENCHMARK(int) + CREATE_RADIX_SORT_BENCHMARK(float) + CREATE_RADIX_SORT_BENCHMARK(long long) + CREATE_RADIX_SORT_BENCHMARK(int8_t) + CREATE_RADIX_SORT_BENCHMARK(uint8_t) + CREATE_RADIX_SORT_BENCHMARK(rocprim::half) + CREATE_RADIX_SORT_BENCHMARK(short) + CREATE_RADIX_SORT_BENCHMARK(rocprim::int128_t) + CREATE_RADIX_SORT_BENCHMARK(rocprim::uint128_t) - // HIP - hipStream_t stream = 0; // default + using custom_float2 = common::custom_type; + using custom_double2 = common::custom_type; - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); + CREATE_RADIX_SORT_BENCHMARK(int, float) + CREATE_RADIX_SORT_BENCHMARK(int, double) + CREATE_RADIX_SORT_BENCHMARK(int, float2) + CREATE_RADIX_SORT_BENCHMARK(int, custom_float2) + CREATE_RADIX_SORT_BENCHMARK(int, double2) + CREATE_RADIX_SORT_BENCHMARK(int, custom_double2) - // Add benchmarks - std::vector benchmarks = {}; -#ifdef BENCHMARK_CONFIG_TUNING - const int parallel_instance = parser.get("parallel_instance"); - const int parallel_instances = parser.get("parallel_instances"); - config_autotune_register::register_benchmark_subset(benchmarks, - parallel_instance, - parallel_instances, - bytes, - seed, - stream); -#else // BENCHMARK_CONFIG_TUNING - add_sort_keys_benchmarks(benchmarks, bytes, seed, stream); - add_sort_pairs_benchmarks(benchmarks, bytes, seed, stream); -#endif // BENCHMARK_CONFIG_TUNING - - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } + CREATE_RADIX_SORT_BENCHMARK(long long, float) + CREATE_RADIX_SORT_BENCHMARK(long long, double) + CREATE_RADIX_SORT_BENCHMARK(long long, float2) + CREATE_RADIX_SORT_BENCHMARK(long long, custom_float2) + CREATE_RADIX_SORT_BENCHMARK(long long, double2) + CREATE_RADIX_SORT_BENCHMARK(long long, custom_double2) + CREATE_RADIX_SORT_BENCHMARK(int8_t, int8_t) + CREATE_RADIX_SORT_BENCHMARK(uint8_t, uint8_t) + CREATE_RADIX_SORT_BENCHMARK(rocprim::half, rocprim::half) + CREATE_RADIX_SORT_BENCHMARK(rocprim::int128_t, rocprim::int128_t) + CREATE_RADIX_SORT_BENCHMARK(rocprim::uint128_t, rocprim::uint128_t) +#endif - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_radix_sort_onesweep.parallel.cpp.in b/benchmark/benchmark_device_radix_sort_onesweep.parallel.cpp.in index 3560bf81e..6be2643c6 100644 --- a/benchmark/benchmark_device_radix_sort_onesweep.parallel.cpp.in +++ b/benchmark/benchmark_device_radix_sort_onesweep.parallel.cpp.in @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -28,6 +28,6 @@ #include namespace { - auto benchmarks = config_autotune_register::create_bulk( + auto unused = benchmark_utils::executor::queue_autotune( device_radix_sort_onesweep_benchmark_generator<@BlockSize@, @RadixBits@, @KeyType@, @ValueType@>::create); } diff --git a/benchmark/benchmark_device_radix_sort_onesweep.parallel.hpp b/benchmark/benchmark_device_radix_sort_onesweep.parallel.hpp index c6813809c..d37fd3015 100644 --- a/benchmark/benchmark_device_radix_sort_onesweep.parallel.hpp +++ b/benchmark/benchmark_device_radix_sort_onesweep.parallel.hpp @@ -27,6 +27,7 @@ #include "../common/utils_custom_type.hpp" #include "../common/utils_data_generation.hpp" +#include "../common/utils_device_ptr.hpp" // Google Benchmark #include @@ -88,7 +89,7 @@ inline std::string config_name() template -struct device_radix_sort_onesweep_benchmark : public config_autotune_interface +struct device_radix_sort_onesweep_benchmark : public benchmark_utils::autotune_interface { std::string name() const override { @@ -98,17 +99,15 @@ struct device_radix_sort_onesweep_benchmark : public config_autotune_interface + ",cfg:" + config_name() + "}"); } - static constexpr unsigned int batch_size = 10; - static constexpr unsigned int warmup_size = 5; - // keys benchmark template - auto do_run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const -> + auto do_run(benchmark_utils::state&& state) const -> typename std::enable_if::value, void>::type { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + using key_type = Key; // Calculate the number of elements @@ -120,27 +119,20 @@ struct device_radix_sort_onesweep_benchmark : public config_autotune_interface common::generate_limits::max(), seed.get_0()); - key_type* d_keys_input; - key_type* d_keys_output; - HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys_input), size * sizeof(key_type))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys_output), size * sizeof(key_type))); - - HIP_CHECK(hipMemcpy(d_keys_input, - keys_input.data(), - size * sizeof(key_type), - hipMemcpyHostToDevice)); + common::device_ptr d_keys_input(keys_input); + common::device_ptr d_keys_output(size); - void* d_temporary_storage = nullptr; + common::device_ptr d_temporary_storage; size_t temporary_storage_bytes = 0; bool is_result_in_output = true; rocprim::empty_type* d_values_ptr = nullptr; HIP_CHECK(( - rocprim::detail::radix_sort_onesweep_impl(d_temporary_storage, + rocprim::detail::radix_sort_onesweep_impl(d_temporary_storage.get(), temporary_storage_bytes, - d_keys_input, + d_keys_input.get(), nullptr, - d_keys_output, + d_keys_output.get(), d_values_ptr, nullptr, d_values_ptr, @@ -153,50 +145,18 @@ struct device_radix_sort_onesweep_benchmark : public config_autotune_interface false, false))); - HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); - HIP_CHECK(hipDeviceSynchronize()); - - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - HIP_CHECK((rocprim::detail::radix_sort_onesweep_impl( - d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - nullptr, - d_keys_output, - d_values_ptr, - nullptr, - d_values_ptr, - size, - is_result_in_output, - rocprim::identity_decomposer{}, - 0, - sizeof(key_type) * 8, - stream, - false, - false))); - } + d_temporary_storage.resize(temporary_storage_bytes); HIP_CHECK(hipDeviceSynchronize()); - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) + state.run( + [&] { HIP_CHECK((rocprim::detail::radix_sort_onesweep_impl( - d_temporary_storage, + d_temporary_storage.get(), temporary_storage_bytes, - d_keys_input, + d_keys_input.get(), nullptr, - d_keys_output, + d_keys_output.get(), d_values_ptr, nullptr, d_values_ptr, @@ -208,37 +168,20 @@ struct device_radix_sort_onesweep_benchmark : public config_autotune_interface stream, false, false))); - } + }); - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(key_type)); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_temporary_storage)); - HIP_CHECK(hipFree(d_keys_input)); - HIP_CHECK(hipFree(d_keys_output)); + state.set_throughput(size, sizeof(key_type)); } // pairs benchmark template - auto do_run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const -> + auto do_run(benchmark_utils::state&& state) const -> typename std::enable_if::value, void>::type { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + using key_type = Key; using value_type = Value; @@ -257,37 +200,25 @@ struct device_radix_sort_onesweep_benchmark : public config_autotune_interface values_input[i] = value_type(i); } - key_type* d_keys_input; - key_type* d_keys_output; - HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys_input), size * sizeof(key_type))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys_output), size * sizeof(key_type))); - HIP_CHECK(hipMemcpy(d_keys_input, - keys_input.data(), - size * sizeof(key_type), - hipMemcpyHostToDevice)); - - value_type* d_values_input; - value_type* d_values_output; - HIP_CHECK(hipMalloc(reinterpret_cast(&d_values_input), size * sizeof(value_type))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_values_output), size * sizeof(value_type))); - HIP_CHECK(hipMemcpy(d_values_input, - values_input.data(), - size * sizeof(value_type), - hipMemcpyHostToDevice)); - - void* d_temporary_storage = nullptr; + common::device_ptr d_keys_input(keys_input); + common::device_ptr d_keys_output(size); + + common::device_ptr d_values_input(values_input); + common::device_ptr d_values_output(size); + + common::device_ptr d_temporary_storage; size_t temporary_storage_bytes = 0; bool is_result_in_output = true; HIP_CHECK(( - rocprim::detail::radix_sort_onesweep_impl(d_temporary_storage, + rocprim::detail::radix_sort_onesweep_impl(d_temporary_storage.get(), temporary_storage_bytes, - d_keys_input, + d_keys_input.get(), nullptr, - d_keys_output, - d_values_input, + d_keys_output.get(), + d_values_input.get(), nullptr, - d_values_output, + d_values_output.get(), size, is_result_in_output, rocprim::identity_decomposer{}, @@ -297,53 +228,21 @@ struct device_radix_sort_onesweep_benchmark : public config_autotune_interface false, false))); - HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); + d_temporary_storage.resize(temporary_storage_bytes); HIP_CHECK(hipDeviceSynchronize()); - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - HIP_CHECK((rocprim::detail::radix_sort_onesweep_impl( - d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - nullptr, - d_keys_output, - d_values_input, - nullptr, - d_values_output, - size, - is_result_in_output, - rocprim::identity_decomposer{}, - 0, - sizeof(key_type) * 8, - stream, - false, - false))); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) + state.run( + [&] { HIP_CHECK((rocprim::detail::radix_sort_onesweep_impl( - d_temporary_storage, + d_temporary_storage.get(), temporary_storage_bytes, - d_keys_input, + d_keys_input.get(), nullptr, - d_keys_output, - d_values_input, + d_keys_output.get(), + d_values_input.get(), nullptr, - d_values_output, + d_values_output.get(), size, is_result_in_output, rocprim::identity_decomposer{}, @@ -352,38 +251,14 @@ struct device_radix_sort_onesweep_benchmark : public config_autotune_interface stream, false, false))); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size - * (sizeof(key_type) + sizeof(value_type))); - state.SetItemsProcessed(state.iterations() * batch_size * size); + }); - HIP_CHECK(hipFree(d_temporary_storage)); - HIP_CHECK(hipFree(d_keys_input)); - HIP_CHECK(hipFree(d_keys_output)); - HIP_CHECK(hipFree(d_values_input)); - HIP_CHECK(hipFree(d_values_output)); + state.set_throughput(size, sizeof(key_type) + sizeof(value_type)); } - void run(benchmark::State& state, - size_t size, - const managed_seed& seed, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { - do_run(state, size, seed, stream); + do_run(std::forward(state)); } }; @@ -398,7 +273,7 @@ struct device_radix_sort_onesweep_benchmark_generator template static constexpr bool is_buildable() { - // Calculation uses `rocprim::arch::wavefront::min_size()`, which is 64 on host side unless overridden. + // Calculation uses `rocprim::arch::wavefront::min_size()`, which is 32 on host side unless overridden. // However, this does not affect the total size of shared memory for the current configuration space. // Were the implementation to change, causing retuning, this needs to be re-evaluated and possibly taken into account. using sharedmem_storage = typename rocprim::detail::onesweep_iteration_helper< @@ -429,7 +304,7 @@ struct device_radix_sort_onesweep_benchmark_generator rocprim::kernel_config, RadixBits, RadixRankAlgorithm>; - void operator()(std::vector>& storage) + void operator()(std::vector>& storage) { storage.emplace_back( std::make_unique< @@ -442,11 +317,12 @@ struct device_radix_sort_onesweep_benchmark_generator RadixRankAlgorithm, std::enable_if_t<(!is_buildable())>> { - void operator()(std::vector>&) const {} + void operator()(std::vector>&) const {} }; template - static void create_algo(std::vector>& storage) + static void + create_algo(std::vector>& storage) { create_ipt<1u, RadixRankAlgorithm>()(storage); create_ipt<4u, RadixRankAlgorithm>()(storage); @@ -458,65 +334,13 @@ struct device_radix_sort_onesweep_benchmark_generator create_ipt<22u, RadixRankAlgorithm>()(storage); } - static void create(std::vector>& storage) + static void create(std::vector>& storage) { create_algo(storage); create_algo(storage); } }; -#else // BENCHMARK_CONFIG_TUNING - - #define CREATE_RADIX_SORT_BENCHMARK(...) \ - { \ - const device_radix_sort_onesweep_benchmark<__VA_ARGS__> instance; \ - REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ - } - -inline void add_sort_keys_benchmarks(std::vector& benchmarks, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) -{ - CREATE_RADIX_SORT_BENCHMARK(int) - CREATE_RADIX_SORT_BENCHMARK(float) - CREATE_RADIX_SORT_BENCHMARK(long long) - CREATE_RADIX_SORT_BENCHMARK(int8_t) - CREATE_RADIX_SORT_BENCHMARK(uint8_t) - CREATE_RADIX_SORT_BENCHMARK(rocprim::half) - CREATE_RADIX_SORT_BENCHMARK(short) - CREATE_RADIX_SORT_BENCHMARK(rocprim::int128_t) - CREATE_RADIX_SORT_BENCHMARK(rocprim::uint128_t) -} - -inline void add_sort_pairs_benchmarks(std::vector& benchmarks, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) -{ - using custom_float2 = common::custom_type; - using custom_double2 = common::custom_type; - - CREATE_RADIX_SORT_BENCHMARK(int, float) - CREATE_RADIX_SORT_BENCHMARK(int, double) - CREATE_RADIX_SORT_BENCHMARK(int, float2) - CREATE_RADIX_SORT_BENCHMARK(int, custom_float2) - CREATE_RADIX_SORT_BENCHMARK(int, double2) - CREATE_RADIX_SORT_BENCHMARK(int, custom_double2) - - CREATE_RADIX_SORT_BENCHMARK(long long, float) - CREATE_RADIX_SORT_BENCHMARK(long long, double) - CREATE_RADIX_SORT_BENCHMARK(long long, float2) - CREATE_RADIX_SORT_BENCHMARK(long long, custom_float2) - CREATE_RADIX_SORT_BENCHMARK(long long, double2) - CREATE_RADIX_SORT_BENCHMARK(long long, custom_double2) - CREATE_RADIX_SORT_BENCHMARK(int8_t, int8_t) - CREATE_RADIX_SORT_BENCHMARK(uint8_t, uint8_t) - CREATE_RADIX_SORT_BENCHMARK(rocprim::half, rocprim::half) - CREATE_RADIX_SORT_BENCHMARK(rocprim::int128_t, rocprim::int128_t) - CREATE_RADIX_SORT_BENCHMARK(rocprim::uint128_t, rocprim::uint128_t) -} - #endif // BENCHMARK_CONFIG_TUNING #endif // ROCPRIM_BENCHMARK_DEVICE_RADIX_SORT_ONESWEEP_PARALLEL_HPP_ diff --git a/benchmark/benchmark_device_reduce.cpp b/benchmark/benchmark_device_reduce.cpp index 2ea1cc741..b6bfe80d2 100644 --- a/benchmark/benchmark_device_reduce.cpp +++ b/benchmark/benchmark_device_reduce.cpp @@ -22,16 +22,11 @@ #include "benchmark_device_reduce.parallel.hpp" #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" #ifndef BENCHMARK_CONFIG_TUNING #include "../common/utils_custom_type.hpp" #endif -// Google Benchmark -#include - // HIP API #include @@ -47,67 +42,14 @@ #include #endif -#ifndef DEFAULT_BYTES -const size_t DEFAULT_BYTES = 1024 * 1024 * 128 * 4; -#endif - -#define CREATE_BENCHMARK(T, REDUCE_OP) \ - { \ - const device_reduce_benchmark instance; \ - REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ - } +#define CREATE_BENCHMARK(T, REDUCE_OP) \ + executor.queue_instance(device_reduce_benchmark()); int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); -#ifdef BENCHMARK_CONFIG_TUNING - // optionally run an evenly split subset of benchmarks, when making multiple program invocations - parser.set_optional("parallel_instance", - "parallel_instance", - 0, - "parallel instance index"); - parser.set_optional("parallel_instances", - "parallel_instances", - 1, - "total parallel instances"); -#endif - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); - - // HIP - hipStream_t stream = 0; // default + benchmark_utils::executor executor(argc, argv, 512 * benchmark_utils::MiB, 10, 5); - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks = {}; -#ifdef BENCHMARK_CONFIG_TUNING - const int parallel_instance = parser.get("parallel_instance"); - const int parallel_instances = parser.get("parallel_instances"); - config_autotune_register::register_benchmark_subset(benchmarks, - parallel_instance, - parallel_instances, - bytes, - seed, - stream); -#else +#ifndef BENCHMARK_CONFIG_TUNING using custom_float2 = common::custom_type; using custom_double2 = common::custom_type; @@ -128,24 +70,5 @@ int main(int argc, char* argv[]) CREATE_BENCHMARK(rocprim::uint128_t, rocprim::plus) #endif - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } - - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_reduce.parallel.cpp.in b/benchmark/benchmark_device_reduce.parallel.cpp.in index a29fd2721..526dd8cd4 100644 --- a/benchmark/benchmark_device_reduce.parallel.cpp.in +++ b/benchmark/benchmark_device_reduce.parallel.cpp.in @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -31,7 +31,7 @@ #include namespace { - auto benchmark = config_autotune_register::create, rocprim::reduce_config<@BlockSize@u, @ItemsPerThread@u, rocprim::block_reduce_algorithm::using_warp_reduce>>>(); } diff --git a/benchmark/benchmark_device_reduce.parallel.hpp b/benchmark/benchmark_device_reduce.parallel.hpp index 9b0d4ef7e..fe6eded1b 100644 --- a/benchmark/benchmark_device_reduce.parallel.hpp +++ b/benchmark/benchmark_device_reduce.parallel.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -25,6 +25,8 @@ #include "benchmark_utils.hpp" +#include "../common/utils_device_ptr.hpp" + // Google Benchmark #include @@ -74,7 +76,7 @@ inline std::string config_name() template, typename Config = rocprim::default_config> -struct device_reduce_benchmark : public config_autotune_interface +struct device_reduce_benchmark : public benchmark_utils::autotune_interface { std::string name() const override { @@ -83,14 +85,12 @@ struct device_reduce_benchmark : public config_autotune_interface + ",cfg:" + config_name() + "}"); } - static constexpr unsigned int batch_size = 10; - static constexpr unsigned int warmup_size = 5; - - void run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + // Calculate the number of elements size_t size = bytes / sizeof(T); @@ -99,84 +99,38 @@ struct device_reduce_benchmark : public config_autotune_interface std::vector input = get_random_data(size, random_range.first, random_range.second, seed.get_0()); - T* d_input; - T* d_output; - HIP_CHECK(hipMalloc(reinterpret_cast(&d_input), size * sizeof(T))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_output), sizeof(T))); - HIP_CHECK(hipMemcpy(d_input, input.data(), size * sizeof(T), hipMemcpyHostToDevice)); + common::device_ptr d_input(input); + common::device_ptr d_output(1); HIP_CHECK(hipDeviceSynchronize()); // Allocate temporary storage memory size_t temp_storage_size_bytes; - void* d_temp_storage = nullptr; // Get size of d_temp_storage - HIP_CHECK(rocprim::reduce(d_temp_storage, + HIP_CHECK(rocprim::reduce(nullptr, temp_storage_size_bytes, - d_input, - d_output, + d_input.get(), + d_output.get(), T(), size, reduce_op, stream)); - HIP_CHECK(hipMalloc(&d_temp_storage, temp_storage_size_bytes)); + common::device_ptr d_temp_storage(temp_storage_size_bytes); HIP_CHECK(hipDeviceSynchronize()); - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - HIP_CHECK(rocprim::reduce(d_temp_storage, - temp_storage_size_bytes, - d_input, - d_output, - T(), - size, - reduce_op, - stream)); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) + state.run( + [&] { - HIP_CHECK(rocprim::reduce(d_temp_storage, + HIP_CHECK(rocprim::reduce(d_temp_storage.get(), temp_storage_size_bytes, - d_input, - d_output, + d_input.get(), + d_output.get(), T(), size, reduce_op, stream)); - } - HIP_CHECK(hipStreamSynchronize(stream)); - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(T)); - state.SetItemsProcessed(state.iterations() * batch_size * size); + }); - HIP_CHECK(hipFree(d_input)); - HIP_CHECK(hipFree(d_output)); - HIP_CHECK(hipFree(d_temp_storage)); + state.set_throughput(size, sizeof(T)); } }; diff --git a/benchmark/benchmark_device_reduce_by_key.cpp b/benchmark/benchmark_device_reduce_by_key.cpp index 82772be17..b7a8fa8b4 100644 --- a/benchmark/benchmark_device_reduce_by_key.cpp +++ b/benchmark/benchmark_device_reduce_by_key.cpp @@ -22,156 +22,14 @@ #include "benchmark_device_reduce_by_key.parallel.hpp" #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" - -#ifndef BENCHMARK_CONFIG_TUNING - #include "../common/utils_custom_type.hpp" -#endif - -// Google Benchmark -#include - -// HIP API -#include - -#ifndef BENCHMARK_CONFIG_TUNING - #include -#endif - -#include -#include -#include -#ifndef BENCHMARK_CONFIG_TUNING - #include -#endif - -#ifndef DEFAULT_BYTES -constexpr size_t DEFAULT_BYTES = size_t{2} << 30; // 2 GiB -#endif - -#define CREATE_BENCHMARK(KEY, VALUE, MAX_SEGMENT_LENGTH) \ - { \ - const device_reduce_by_key_benchmark instance; \ - REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ - } - -#define CREATE_BENCHMARK_TYPE(KEY, VALUE) \ - CREATE_BENCHMARK(KEY, VALUE, 10); \ - CREATE_BENCHMARK(KEY, VALUE, 1000) - -// some of the tuned types -#define CREATE_BENCHMARK_TYPES(KEY) \ - CREATE_BENCHMARK_TYPE(KEY, int8_t); \ - CREATE_BENCHMARK_TYPE(KEY, rocprim::half); \ - CREATE_BENCHMARK_TYPE(KEY, int32_t); \ - CREATE_BENCHMARK_TYPE(KEY, rocprim::int128_t); \ - CREATE_BENCHMARK_TYPE(KEY, rocprim::uint128_t); \ - CREATE_BENCHMARK_TYPE(KEY, float); \ - CREATE_BENCHMARK_TYPE(KEY, double) - -// all of the tuned types -#define CREATE_BENCHMARK_TYPE_TUNING(KEY) \ - CREATE_BENCHMARK_TYPE(KEY, int8_t); \ - CREATE_BENCHMARK_TYPE(KEY, int16_t); \ - CREATE_BENCHMARK_TYPE(KEY, int32_t); \ - CREATE_BENCHMARK_TYPE(KEY, int64_t); \ - CREATE_BENCHMARK_TYPE(KEY, rocprim::int128_t); \ - CREATE_BENCHMARK_TYPE(KEY, rocprim::uint128_t); \ - CREATE_BENCHMARK_TYPE(KEY, rocprim::half); \ - CREATE_BENCHMARK_TYPE(KEY, float); \ - CREATE_BENCHMARK_TYPE(KEY, double) int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); -#ifdef BENCHMARK_CONFIG_TUNING - // optionally run an evenly split subset of benchmarks, when making multiple program invocations - parser.set_optional("parallel_instance", - "parallel_instance", - 0, - "parallel instance index"); - parser.set_optional("parallel_instances", - "parallel_instances", - 1, - "total parallel instances"); -#endif - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); - - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); - benchmark::AddCustomContext("seed", seed_type); + benchmark_utils::executor executor(argc, argv, 2 * benchmark_utils::GiB, 10, 5); - // Add benchmarks - std::vector benchmarks = {}; -#ifdef BENCHMARK_CONFIG_TUNING - const int parallel_instance = parser.get("parallel_instance"); - const int parallel_instances = parser.get("parallel_instances"); - config_autotune_register::register_benchmark_subset(benchmarks, - parallel_instance, - parallel_instances, - size, - seed, - stream); -#else - // tuned types - CREATE_BENCHMARK_TYPES(int8_t); - CREATE_BENCHMARK_TYPES(int16_t); - CREATE_BENCHMARK_TYPE_TUNING(int32_t); - CREATE_BENCHMARK_TYPE_TUNING(int64_t); - CREATE_BENCHMARK_TYPES(rocprim::half); - CREATE_BENCHMARK_TYPES(float); - CREATE_BENCHMARK_TYPES(double); - CREATE_BENCHMARK_TYPES(rocprim::int128_t); - CREATE_BENCHMARK_TYPES(rocprim::uint128_t); - - // custom types - using custom_float2 = common::custom_type; - using custom_double2 = common::custom_type; - - CREATE_BENCHMARK_TYPE(int, custom_float2); - CREATE_BENCHMARK_TYPE(int, custom_double2); - - CREATE_BENCHMARK_TYPE(long long, custom_float2); - CREATE_BENCHMARK_TYPE(long long, custom_double2); +#ifndef BENCHMARK_CONFIG_TUNING + add_benchmarks(executor); #endif - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } - - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_reduce_by_key.parallel.cpp.in b/benchmark/benchmark_device_reduce_by_key.parallel.cpp.in index e43d984f3..028a31699 100644 --- a/benchmark/benchmark_device_reduce_by_key.parallel.cpp.in +++ b/benchmark/benchmark_device_reduce_by_key.parallel.cpp.in @@ -28,6 +28,6 @@ #include namespace { - auto benchmarks = config_autotune_register::create_bulk( + auto unused = benchmark_utils::executor::queue_autotune( device_reduce_by_key_benchmark_generator<@KeyType@, @ValueType@, @BlockSize@>::create); } diff --git a/benchmark/benchmark_device_reduce_by_key.parallel.hpp b/benchmark/benchmark_device_reduce_by_key.parallel.hpp index d9e8a6be9..710ff66a9 100644 --- a/benchmark/benchmark_device_reduce_by_key.parallel.hpp +++ b/benchmark/benchmark_device_reduce_by_key.parallel.hpp @@ -25,6 +25,8 @@ #include "benchmark_utils.hpp" +#include "../common/utils_device_ptr.hpp" + // Google Benchmark #include @@ -70,9 +72,9 @@ inline std::string config_name() template -struct device_reduce_by_key_benchmark : public config_autotune_interface + bool Deterministic, + typename Config = rocprim::default_config> +struct device_reduce_by_key_benchmark : public benchmark_utils::autotune_interface { std::string name() const override { @@ -82,13 +84,12 @@ struct device_reduce_by_key_benchmark : public config_autotune_interface + std::to_string(MaxSegmentLength) + ",cfg:" + config_name() + "}"); } - void run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { - constexpr int batch_size = 10; - constexpr int warmup_size = 5; + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + constexpr std::array tuning_max_segment_lengths = {10, 1000}; constexpr int num_input_arrays = is_tuning ? tuning_max_segment_lengths.size() : 1; constexpr size_t item_size = sizeof(KeyType) + sizeof(ValueType); @@ -113,29 +114,17 @@ struct device_reduce_by_key_benchmark : public config_autotune_interface std::vector value_input(size); std::iota(value_input.begin(), value_input.end(), 0); - KeyType* d_key_inputs[num_input_arrays]; + common::device_ptr d_key_inputs[num_input_arrays]; for(int i = 0; i < num_input_arrays; ++i) { - HIP_CHECK(hipMalloc(&d_key_inputs[i], size * sizeof(*d_key_inputs[i]))); - HIP_CHECK(hipMemcpy(d_key_inputs[i], - key_inputs[i].data(), - size * sizeof(*d_key_inputs[i]), - hipMemcpyHostToDevice)); + d_key_inputs[i].store(key_inputs[i]); } - ValueType* d_value_input; - HIP_CHECK(hipMalloc(&d_value_input, size * sizeof(*d_value_input))); - HIP_CHECK(hipMemcpy(d_value_input, - value_input.data(), - size * sizeof(*d_value_input), - hipMemcpyHostToDevice)); + common::device_ptr d_value_input(value_input); - KeyType* d_unique_output; - ValueType* d_aggregates_output; - unsigned int* d_unique_count_output; - HIP_CHECK(hipMalloc(&d_unique_output, size * sizeof(*d_unique_output))); - HIP_CHECK(hipMalloc(&d_aggregates_output, size * sizeof(*d_aggregates_output))); - HIP_CHECK(hipMalloc(&d_unique_count_output, sizeof(*d_unique_count_output))); + common::device_ptr d_unique_output(size); + common::device_ptr d_aggregates_output(size); + common::device_ptr d_unique_count_output(1); rocprim::plus reduce_op; rocprim::equal_to key_compare_op; @@ -144,33 +133,34 @@ struct device_reduce_by_key_benchmark : public config_autotune_interface { const auto dispatch_input = [&](KeyType* d_key_input) { - if ROCPRIM_IF_CONSTEXPR(!Deterministic) + if constexpr(!Deterministic) { HIP_CHECK(rocprim::reduce_by_key(d_temp_storage, temp_storage_size_bytes, d_key_input, - d_value_input, + d_value_input.get(), size, - d_unique_output, - d_aggregates_output, - d_unique_count_output, + d_unique_output.get(), + d_aggregates_output.get(), + d_unique_count_output.get(), reduce_op, key_compare_op, stream)); } else { - HIP_CHECK(rocprim::deterministic_reduce_by_key(d_temp_storage, - temp_storage_size_bytes, - d_key_input, - d_value_input, - size, - d_unique_output, - d_aggregates_output, - d_unique_count_output, - reduce_op, - key_compare_op, - stream)); + HIP_CHECK( + rocprim::deterministic_reduce_by_key(d_temp_storage, + temp_storage_size_bytes, + d_key_input, + d_value_input.get(), + size, + d_unique_output.get(), + d_aggregates_output.get(), + d_unique_count_output.get(), + reduce_op, + key_compare_op, + stream)); } }; @@ -180,56 +170,18 @@ struct device_reduce_by_key_benchmark : public config_autotune_interface // generally larger segments perform better. for(int i = 0; i < num_input_arrays; ++i) { - dispatch_input(d_key_inputs[i]); + dispatch_input(d_key_inputs[i].get()); } }; // Allocate temporary storage memory size_t temp_storage_size_bytes{}; dispatch(nullptr, temp_storage_size_bytes); - void* d_temp_storage{}; - HIP_CHECK(hipMalloc(&d_temp_storage, temp_storage_size_bytes)); - - for(int i = 0; i < warmup_size; ++i) - { - dispatch(d_temp_storage, temp_storage_size_bytes); - } - HIP_CHECK(hipDeviceSynchronize()); + common::device_ptr d_temp_storage(temp_storage_size_bytes); - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); + state.run([&] { dispatch(d_temp_storage.get(), temp_storage_size_bytes); }); - for(auto _ : state) - { - HIP_CHECK(hipEventRecord(start, stream)); - for(int i = 0; i < batch_size; ++i) - { - dispatch(d_temp_storage, temp_storage_size_bytes); - } - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds{}; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * item_size); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_temp_storage)); - for(int i = 0; i < num_input_arrays; ++i) - { - HIP_CHECK(hipFree(d_key_inputs[i])); - } - HIP_CHECK(hipFree(d_value_input)); - HIP_CHECK(hipFree(d_unique_output)); - HIP_CHECK(hipFree(d_aggregates_output)); - HIP_CHECK(hipFree(d_unique_count_output)); + state.set_throughput(size, sizeof(KeyType) + sizeof(ValueType)); } static constexpr bool is_tuning = !std::is_same::value; @@ -243,7 +195,7 @@ struct device_reduce_by_key_benchmark_generator template struct create_ipt { - void operator()(std::vector>& storage) + void operator()(std::vector>& storage) { using config = rocprim::reduce_by_key_config>& storage) + static void create(std::vector>& storage) { static constexpr unsigned int max_items_per_thread = std::min( TUNING_SHARED_MEMORY_MAX / std::max(sizeof(KeyType), sizeof(ValueType)) / BlockSize - 1, @@ -270,4 +222,59 @@ struct device_reduce_by_key_benchmark_generator #endif // BENCHMARK_CONFIG_TUNING +#define CREATE_BENCHMARK(KEY, VALUE, MAX_SEGMENT_LENGTH) \ + executor.queue_instance( \ + device_reduce_by_key_benchmark()); + +#define CREATE_BENCHMARK_TYPE(KEY, VALUE) \ + CREATE_BENCHMARK(KEY, VALUE, 10) \ + CREATE_BENCHMARK(KEY, VALUE, 1000) + +// some of the tuned types +#define CREATE_BENCHMARK_TYPES(KEY) \ + CREATE_BENCHMARK_TYPE(KEY, int8_t) \ + CREATE_BENCHMARK_TYPE(KEY, rocprim::half) \ + CREATE_BENCHMARK_TYPE(KEY, int32_t) \ + CREATE_BENCHMARK_TYPE(KEY, rocprim::int128_t) \ + CREATE_BENCHMARK_TYPE(KEY, rocprim::uint128_t) \ + CREATE_BENCHMARK_TYPE(KEY, float) \ + CREATE_BENCHMARK_TYPE(KEY, double) + +// all of the tuned types +#define CREATE_BENCHMARK_TYPE_TUNING(KEY) \ + CREATE_BENCHMARK_TYPE(KEY, int8_t) \ + CREATE_BENCHMARK_TYPE(KEY, int16_t) \ + CREATE_BENCHMARK_TYPE(KEY, int32_t) \ + CREATE_BENCHMARK_TYPE(KEY, int64_t) \ + CREATE_BENCHMARK_TYPE(KEY, rocprim::int128_t) \ + CREATE_BENCHMARK_TYPE(KEY, rocprim::uint128_t) \ + CREATE_BENCHMARK_TYPE(KEY, rocprim::half) \ + CREATE_BENCHMARK_TYPE(KEY, float) \ + CREATE_BENCHMARK_TYPE(KEY, double) + +template +void add_benchmarks(benchmark_utils::executor& executor) +{ + // tuned types + CREATE_BENCHMARK_TYPES(int8_t) + CREATE_BENCHMARK_TYPES(int16_t) + CREATE_BENCHMARK_TYPE_TUNING(int32_t) + CREATE_BENCHMARK_TYPE_TUNING(int64_t) + CREATE_BENCHMARK_TYPES(rocprim::half) + CREATE_BENCHMARK_TYPES(float) + CREATE_BENCHMARK_TYPES(double) + CREATE_BENCHMARK_TYPES(rocprim::int128_t) + CREATE_BENCHMARK_TYPES(rocprim::uint128_t) + + // custom types + using custom_float2 = common::custom_type; + using custom_double2 = common::custom_type; + + CREATE_BENCHMARK_TYPE(int, custom_float2) + CREATE_BENCHMARK_TYPE(int, custom_double2) + + CREATE_BENCHMARK_TYPE(long long, custom_float2) + CREATE_BENCHMARK_TYPE(long long, custom_double2) +} + #endif // ROCPRIM_BENCHMARK_DEVICE_REDUCE_BY_KEY_PARALLEL_HPP_ diff --git a/benchmark/benchmark_device_reduce_by_key_deterministic.cpp b/benchmark/benchmark_device_reduce_by_key_deterministic.cpp index b7dff1a7c..8b9d07467 100644 --- a/benchmark/benchmark_device_reduce_by_key_deterministic.cpp +++ b/benchmark/benchmark_device_reduce_by_key_deterministic.cpp @@ -22,127 +22,12 @@ #include "benchmark_device_reduce_by_key.parallel.hpp" #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" - -#include "../common/utils_custom_type.hpp" - -// Google Benchmark -#include - -// HIP API -#include - -#include - -#include -#include -#include - -#ifndef DEFAULT_BYTES -constexpr size_t DEFAULT_BYTES = size_t{2} << 30; // 2 GiB -#endif - -#define CREATE_BENCHMARK(KEY, VALUE, MAX_SEGMENT_LENGTH) \ - { \ - const device_reduce_by_key_benchmark instance; \ - REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ - } - -#define CREATE_BENCHMARK_TYPE(KEY, VALUE) \ - CREATE_BENCHMARK(KEY, VALUE, 10); \ - CREATE_BENCHMARK(KEY, VALUE, 1000) - -// some of the tuned types -#define CREATE_BENCHMARK_TYPES(KEY) \ - CREATE_BENCHMARK_TYPE(KEY, int8_t); \ - CREATE_BENCHMARK_TYPE(KEY, rocprim::half); \ - CREATE_BENCHMARK_TYPE(KEY, int32_t); \ - CREATE_BENCHMARK_TYPE(KEY, float); \ - CREATE_BENCHMARK_TYPE(KEY, double); \ - CREATE_BENCHMARK_TYPE(KEY, rocprim::int128_t); \ - CREATE_BENCHMARK_TYPE(KEY, rocprim::uint128_t) - -// all of the tuned types -#define CREATE_BENCHMARK_TYPE_TUNING(KEY) \ - CREATE_BENCHMARK_TYPE(KEY, int8_t); \ - CREATE_BENCHMARK_TYPE(KEY, int16_t); \ - CREATE_BENCHMARK_TYPE(KEY, int32_t); \ - CREATE_BENCHMARK_TYPE(KEY, int64_t); \ - CREATE_BENCHMARK_TYPE(KEY, rocprim::half); \ - CREATE_BENCHMARK_TYPE(KEY, float); \ - CREATE_BENCHMARK_TYPE(KEY, double); \ - CREATE_BENCHMARK_TYPE(KEY, rocprim::int128_t); \ - CREATE_BENCHMARK_TYPE(KEY, rocprim::uint128_t) int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); - - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks = {}; - // tuned types - CREATE_BENCHMARK_TYPES(int8_t); - CREATE_BENCHMARK_TYPES(int16_t); - CREATE_BENCHMARK_TYPE_TUNING(int32_t); - CREATE_BENCHMARK_TYPE_TUNING(int64_t); - CREATE_BENCHMARK_TYPES(rocprim::half); - CREATE_BENCHMARK_TYPES(float); - CREATE_BENCHMARK_TYPES(double); - CREATE_BENCHMARK_TYPES(rocprim::int128_t); - CREATE_BENCHMARK_TYPES(rocprim::uint128_t); - - // custom types - using custom_float2 = common::custom_type; - using custom_double2 = common::custom_type; - - CREATE_BENCHMARK_TYPE(int, custom_float2); - CREATE_BENCHMARK_TYPE(int, custom_double2); - - CREATE_BENCHMARK_TYPE(long long, custom_float2); - CREATE_BENCHMARK_TYPE(long long, custom_double2); - - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } + benchmark_utils::executor executor(argc, argv, 2 * benchmark_utils::GiB, 10, 5); - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } + add_benchmarks(executor); - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_run_length_encode.cpp b/benchmark/benchmark_device_run_length_encode.cpp index 80eb8daba..35512c736 100644 --- a/benchmark/benchmark_device_run_length_encode.cpp +++ b/benchmark/benchmark_device_run_length_encode.cpp @@ -23,14 +23,8 @@ #include "benchmark_device_run_length_encode.parallel.hpp" #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" - #include "../common/utils_custom_type.hpp" -// Google Benchmark -#include - // HIP API #include @@ -42,114 +36,40 @@ #include #include -#ifndef DEFAULT_BYTES -constexpr size_t DEFAULT_BYTES = size_t{2} << 30; // 2 GiB -#endif - -#define CREATE_ENCODE_BENCHMARK(T, ML) \ - { \ - const device_run_length_encode_benchmark instance; \ - REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ - } +#define CREATE_ENCODE_BENCHMARK(T, ML) \ + executor.queue_instance(device_run_length_encode_benchmark()); template -void add_encode_benchmarks(std::vector& benchmarks, - size_t size, - const managed_seed& seed, - hipStream_t stream) +void add_encode_benchmarks(benchmark_utils::executor& executor) { using custom_float2 = common::custom_type; using custom_double2 = common::custom_type; // all tuned types - CREATE_ENCODE_BENCHMARK(int8_t, MaxLength); - CREATE_ENCODE_BENCHMARK(int16_t, MaxLength); - CREATE_ENCODE_BENCHMARK(int32_t, MaxLength); - CREATE_ENCODE_BENCHMARK(int64_t, MaxLength); - CREATE_ENCODE_BENCHMARK(rocprim::int128_t, MaxLength); - CREATE_ENCODE_BENCHMARK(rocprim::uint128_t, MaxLength); - CREATE_ENCODE_BENCHMARK(rocprim::half, MaxLength); - CREATE_ENCODE_BENCHMARK(float, MaxLength); - CREATE_ENCODE_BENCHMARK(double, MaxLength); + CREATE_ENCODE_BENCHMARK(int8_t, MaxLength) + CREATE_ENCODE_BENCHMARK(int16_t, MaxLength) + CREATE_ENCODE_BENCHMARK(int32_t, MaxLength) + CREATE_ENCODE_BENCHMARK(int64_t, MaxLength) + CREATE_ENCODE_BENCHMARK(rocprim::int128_t, MaxLength) + CREATE_ENCODE_BENCHMARK(rocprim::uint128_t, MaxLength) + CREATE_ENCODE_BENCHMARK(rocprim::half, MaxLength) + CREATE_ENCODE_BENCHMARK(float, MaxLength) + CREATE_ENCODE_BENCHMARK(double, MaxLength) // custom types - CREATE_ENCODE_BENCHMARK(custom_float2, MaxLength); - CREATE_ENCODE_BENCHMARK(custom_double2, MaxLength); + CREATE_ENCODE_BENCHMARK(custom_float2, MaxLength) + CREATE_ENCODE_BENCHMARK(custom_double2, MaxLength) } int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); -#ifdef BENCHMARK_CONFIG_TUNING - // optionally run an evenly split subset of benchmarks, when making multiple program invocations - parser.set_optional("parallel_instance", - "parallel_instance", - 0, - "parallel instance index"); - parser.set_optional("parallel_instances", - "parallel_instances", - 1, - "total parallel instances"); -#endif + benchmark_utils::executor executor(argc, argv, 2 * benchmark_utils::GiB, 10, 10); - parser.run_and_exit_if_error(); +#ifndef BENCHMARK_CONFIG_TUNING - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); + add_encode_benchmarks<1000>(executor); + add_encode_benchmarks<10>(executor); - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); - benchmark::AddCustomContext("seed", seed_type); - - std::vector benchmarks; - - // Add benchmarks -#ifdef BENCHMARK_CONFIG_TUNING - const int parallel_instance = parser.get("parallel_instance"); - const int parallel_instances = parser.get("parallel_instances"); - config_autotune_register::register_benchmark_subset(benchmarks, - parallel_instance, - parallel_instances, - size, - seed, - stream); -#else - add_encode_benchmarks<1000>(benchmarks, size, seed, stream); - add_encode_benchmarks<10>(benchmarks, size, seed, stream); #endif - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } - - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_run_length_encode.parallel.cpp.in b/benchmark/benchmark_device_run_length_encode.parallel.cpp.in index c3d6888fc..2db6eba1e 100644 --- a/benchmark/benchmark_device_run_length_encode.parallel.cpp.in +++ b/benchmark/benchmark_device_run_length_encode.parallel.cpp.in @@ -28,7 +28,7 @@ #include namespace { - auto benchmarks = config_autotune_register::create_bulk( + auto unused = benchmark_utils::executor::queue_autotune( device_run_length_encode_benchmark_generator< @KeyType@, @BlockSize@>::create); diff --git a/benchmark/benchmark_device_run_length_encode.parallel.hpp b/benchmark/benchmark_device_run_length_encode.parallel.hpp index e0a84146a..1f9da997d 100644 --- a/benchmark/benchmark_device_run_length_encode.parallel.hpp +++ b/benchmark/benchmark_device_run_length_encode.parallel.hpp @@ -25,6 +25,8 @@ #include "benchmark_utils.hpp" +#include "../common/utils_device_ptr.hpp" + // Google Benchmark #include @@ -63,7 +65,7 @@ inline std::string run_length_encode_config_name() } template -struct device_run_length_encode_benchmark : public config_autotune_interface +struct device_run_length_encode_benchmark : public benchmark_utils::autotune_interface { std::string name() const override { @@ -72,11 +74,12 @@ struct device_run_length_encode_benchmark : public config_autotune_interface + ",keys_max_length:" + std::to_string(MaxLength) + ",cfg:" + run_length_encode_config_name() + "}"); } - void run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + using key_type = T; using count_type = unsigned int; @@ -105,95 +108,41 @@ struct device_run_length_encode_benchmark : public config_autotune_interface offset += key_count; } - key_type* d_input; - HIP_CHECK(hipMalloc(reinterpret_cast(&d_input), size * sizeof(key_type))); - HIP_CHECK(hipMemcpy(d_input, input.data(), size * sizeof(key_type), hipMemcpyHostToDevice)); + common::device_ptr d_input(input); - key_type* d_unique_output; - count_type* d_counts_output; - count_type* d_runs_count_output; - HIP_CHECK( - hipMalloc(reinterpret_cast(&d_unique_output), runs_count * sizeof(key_type))); - HIP_CHECK( - hipMalloc(reinterpret_cast(&d_counts_output), runs_count * sizeof(count_type))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_runs_count_output), sizeof(count_type))); + common::device_ptr d_unique_output(runs_count); + common::device_ptr d_counts_output(runs_count); + common::device_ptr d_runs_count_output(1); - void* d_temporary_storage = nullptr; size_t temporary_storage_bytes = 0; HIP_CHECK(rocprim::run_length_encode(nullptr, temporary_storage_bytes, - d_input, + d_input.get(), size, - d_unique_output, - d_counts_output, - d_runs_count_output, + d_unique_output.get(), + d_counts_output.get(), + d_runs_count_output.get(), stream, false)); - HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); - HIP_CHECK(hipDeviceSynchronize()); - - // Warm-up - for(size_t i = 0; i < 10; ++i) - { - HIP_CHECK(rocprim::run_length_encode(d_temporary_storage, - temporary_storage_bytes, - d_input, - size, - d_unique_output, - d_counts_output, - d_runs_count_output, - stream, - false)); - } + common::device_ptr d_temporary_storage(temporary_storage_bytes); HIP_CHECK(hipDeviceSynchronize()); - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - const unsigned int batch_size = 10; - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) + state.run( + [&] { - HIP_CHECK(rocprim::run_length_encode(d_temporary_storage, + HIP_CHECK(rocprim::run_length_encode(d_temporary_storage.get(), temporary_storage_bytes, - d_input, + d_input.get(), size, - d_unique_output, - d_counts_output, - d_runs_count_output, + d_unique_output.get(), + d_counts_output.get(), + d_runs_count_output.get(), stream, false)); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(key_type)); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_temporary_storage)); - HIP_CHECK(hipFree(d_input)); - HIP_CHECK(hipFree(d_unique_output)); - HIP_CHECK(hipFree(d_counts_output)); - HIP_CHECK(hipFree(d_runs_count_output)); + }); + state.set_throughput(size, sizeof(key_type)); } }; @@ -205,7 +154,7 @@ struct device_run_length_encode_benchmark_generator template struct create_ipt { - void operator()(std::vector>& storage) + void operator()(std::vector>& storage) { using config = rocprim::reduce_by_key_config>& storage) + static void create(std::vector>& storage) { static constexpr unsigned int max_items_per_thread = std::min(TUNING_SHARED_MEMORY_MAX / sizeof(T) / BlockSize - 1, size_t{15}); diff --git a/benchmark/benchmark_device_run_length_encode_non_trivial_runs.cpp b/benchmark/benchmark_device_run_length_encode_non_trivial_runs.cpp index c5a049f06..a1e0a3d49 100644 --- a/benchmark/benchmark_device_run_length_encode_non_trivial_runs.cpp +++ b/benchmark/benchmark_device_run_length_encode_non_trivial_runs.cpp @@ -42,115 +42,39 @@ #include #include -#ifndef DEFAULT_BYTES -constexpr size_t DEFAULT_BYTES = size_t{2} << 30; // 2 GiB -#endif - // CHANGE -#define CREATE_NON_TRIVIAL_RUNS_BENCHMARK(T, ML) \ - { \ - const device_non_trivial_runs_benchmark instance; \ - REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ - } +#define CREATE_NON_TRIVIAL_RUNS_BENCHMARK(T, ML) \ + executor.queue_instance(device_non_trivial_runs_benchmark()); template -void add_non_trivial_runs_benchmarks(std::vector& benchmarks, - size_t size, - const managed_seed& seed, - hipStream_t stream) +void add_non_trivial_runs_benchmarks(benchmark_utils::executor& executor) { using custom_float2 = common::custom_type; using custom_double2 = common::custom_type; - CREATE_NON_TRIVIAL_RUNS_BENCHMARK(int8_t, MaxLength); - CREATE_NON_TRIVIAL_RUNS_BENCHMARK(int16_t, MaxLength); - CREATE_NON_TRIVIAL_RUNS_BENCHMARK(int32_t, MaxLength); - CREATE_NON_TRIVIAL_RUNS_BENCHMARK(int64_t, MaxLength); - CREATE_NON_TRIVIAL_RUNS_BENCHMARK(rocprim::int128_t, MaxLength); - CREATE_NON_TRIVIAL_RUNS_BENCHMARK(rocprim::uint128_t, MaxLength); - CREATE_NON_TRIVIAL_RUNS_BENCHMARK(rocprim::half, MaxLength); - CREATE_NON_TRIVIAL_RUNS_BENCHMARK(float, MaxLength); - CREATE_NON_TRIVIAL_RUNS_BENCHMARK(double, MaxLength); - - CREATE_NON_TRIVIAL_RUNS_BENCHMARK(custom_float2, MaxLength); - CREATE_NON_TRIVIAL_RUNS_BENCHMARK(custom_double2, MaxLength); + CREATE_NON_TRIVIAL_RUNS_BENCHMARK(int8_t, MaxLength) + CREATE_NON_TRIVIAL_RUNS_BENCHMARK(int16_t, MaxLength) + CREATE_NON_TRIVIAL_RUNS_BENCHMARK(int32_t, MaxLength) + CREATE_NON_TRIVIAL_RUNS_BENCHMARK(int64_t, MaxLength) + CREATE_NON_TRIVIAL_RUNS_BENCHMARK(rocprim::int128_t, MaxLength) + CREATE_NON_TRIVIAL_RUNS_BENCHMARK(rocprim::uint128_t, MaxLength) + CREATE_NON_TRIVIAL_RUNS_BENCHMARK(rocprim::half, MaxLength) + CREATE_NON_TRIVIAL_RUNS_BENCHMARK(float, MaxLength) + CREATE_NON_TRIVIAL_RUNS_BENCHMARK(double, MaxLength) + // custom types + CREATE_NON_TRIVIAL_RUNS_BENCHMARK(custom_float2, MaxLength) + CREATE_NON_TRIVIAL_RUNS_BENCHMARK(custom_double2, MaxLength) } int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); -#ifdef BENCHMARK_CONFIG_TUNING - // optionally run an evenly split subset of benchmarks, when making multiple program invocations - parser.set_optional("parallel_instance", - "parallel_instance", - 0, - "parallel instance index"); - parser.set_optional("parallel_instances", - "parallel_instances", - 1, - "total parallel instances"); -#endif - - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); - - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); - benchmark::AddCustomContext("seed", seed_type); - - std::vector benchmarks; - + benchmark_utils::executor executor(argc, argv, 2 * benchmark_utils::GiB, 10, 5); // Add benchmarks -#ifdef BENCHMARK_CONFIG_TUNING - const int parallel_instance = parser.get("parallel_instance"); - const int parallel_instances = parser.get("parallel_instances"); - config_autotune_register::register_benchmark_subset(benchmarks, - parallel_instance, - parallel_instances, - size, - seed, - stream); -#else - add_non_trivial_runs_benchmarks<16>(benchmarks, size, seed, stream); - add_non_trivial_runs_benchmarks<256>(benchmarks, size, seed, stream); - add_non_trivial_runs_benchmarks<4096>(benchmarks, size, seed, stream); -#endif +#ifndef BENCHMARK_CONFIG_TUNING - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } - - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + add_non_trivial_runs_benchmarks<16>(executor); + add_non_trivial_runs_benchmarks<256>(executor); + add_non_trivial_runs_benchmarks<4096>(executor); +#endif + executor.run(); } diff --git a/benchmark/benchmark_device_run_length_encode_non_trivial_runs.parallel.cpp.in b/benchmark/benchmark_device_run_length_encode_non_trivial_runs.parallel.cpp.in index cf9e88a84..f0b99757b 100644 --- a/benchmark/benchmark_device_run_length_encode_non_trivial_runs.parallel.cpp.in +++ b/benchmark/benchmark_device_run_length_encode_non_trivial_runs.parallel.cpp.in @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -29,7 +29,7 @@ #include namespace { - auto benchmarks = config_autotune_register::create_bulk( + auto unused = benchmark_utils::executor::queue_autotune( device_non_trivial_runs_benchmark_generator< @KeyType@, @BlockSize@, diff --git a/benchmark/benchmark_device_run_length_encode_non_trivial_runs.parallel.hpp b/benchmark/benchmark_device_run_length_encode_non_trivial_runs.parallel.hpp index add6a686e..247d820da 100644 --- a/benchmark/benchmark_device_run_length_encode_non_trivial_runs.parallel.hpp +++ b/benchmark/benchmark_device_run_length_encode_non_trivial_runs.parallel.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -25,6 +25,8 @@ #include "benchmark_utils.hpp" +#include "../common/utils_device_ptr.hpp" + // Google Benchmark #include @@ -69,7 +71,7 @@ inline std::string non_trivial_runs_config_name() } template -struct device_non_trivial_runs_benchmark : public config_autotune_interface +struct device_non_trivial_runs_benchmark : public benchmark_utils::autotune_interface { std::string name() const override { @@ -79,16 +81,15 @@ struct device_non_trivial_runs_benchmark : public config_autotune_interface + ",cfg:" + non_trivial_runs_config_name() + "}"); } - void run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + using offset_type = unsigned int; using count_type = unsigned int; - constexpr int batch_size = 10; - constexpr int warmup_size = 5; constexpr std::array tuning_max_segment_lengths = {10, 1000}; constexpr int num_input_arrays = is_tuning ? tuning_max_segment_lengths.size() : 1; @@ -112,22 +113,15 @@ struct device_non_trivial_runs_benchmark : public config_autotune_interface input[0] = get_random_segments_iota(size, MaxLength, seed.get_0()); } - T* d_input[num_input_arrays]; + common::device_ptr d_input[num_input_arrays]; for(int i = 0; i < num_input_arrays; ++i) { - HIP_CHECK(hipMalloc(&d_input[i], size * sizeof(*d_input[i]))); - HIP_CHECK(hipMemcpy(d_input[i], - input[i].data(), - size * sizeof(*d_input[i]), - hipMemcpyHostToDevice)); + d_input[i].store(input[i]); } - offset_type* d_offsets_output; - HIP_CHECK(hipMalloc(&d_offsets_output, size * sizeof(*d_offsets_output))); - count_type* d_counts_output; - HIP_CHECK(hipMalloc(&d_counts_output, size * sizeof(*d_counts_output))); - count_type* d_runs_count_output; - HIP_CHECK(hipMalloc(&d_runs_count_output, sizeof(*d_runs_count_output))); + common::device_ptr d_offsets_output(size); + common::device_ptr d_counts_output(size); + common::device_ptr d_runs_count_output(1); const auto dispatch = [&](void* d_temporary_storage, size_t& temporary_storage_bytes) { @@ -138,76 +132,28 @@ struct device_non_trivial_runs_benchmark : public config_autotune_interface temporary_storage_bytes, d_input, size, - d_offsets_output, - d_counts_output, - d_runs_count_output, + d_offsets_output.get(), + d_counts_output.get(), + d_runs_count_output.get(), stream, false)); }; for(int i = 0; i < num_input_arrays; ++i) { - dispatch_input(d_input[i]); + dispatch_input(d_input[i].get()); } }; // Allocate temporary storage memory size_t temporary_storage_bytes = 0; dispatch(nullptr, temporary_storage_bytes); - void* d_temporary_storage; - HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); - HIP_CHECK(hipDeviceSynchronize()); - - // Warm-up - for(int i = 0; i < warmup_size; ++i) - { - dispatch(d_temporary_storage, temporary_storage_bytes); - } + common::device_ptr d_temporary_storage(temporary_storage_bytes); HIP_CHECK(hipDeviceSynchronize()); - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(int i = 0; i < batch_size; ++i) - { - dispatch(d_temporary_storage, temporary_storage_bytes); - } - HIP_CHECK(hipStreamSynchronize(stream)); - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - HIP_CHECK(hipGetLastError()); - HIP_CHECK(hipDeviceSynchronize()); - - float elapsed_mseconds{}; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * item_size); - state.SetItemsProcessed(state.iterations() * batch_size * size); + state.run([&] { dispatch(d_temporary_storage.get(), temporary_storage_bytes); }); - HIP_CHECK(hipFree(d_temporary_storage)); - for(int i = 0; i < num_input_arrays; ++i) - { - HIP_CHECK(hipFree(d_input[i])); - } - HIP_CHECK(hipFree(d_offsets_output)); - HIP_CHECK(hipFree(d_counts_output)); - HIP_CHECK(hipFree(d_runs_count_output)); + state.set_throughput(size, sizeof(T) + sizeof(offset_type) + sizeof(count_type)); } static constexpr bool is_tuning = !std::is_same::value; }; @@ -238,7 +184,7 @@ struct device_non_trivial_runs_benchmark_generator template struct create_ipt { - void operator()(std::vector>& storage) + void operator()(std::vector>& storage) { if(!is_load_warp_transpose || is_warp_load_supp) { @@ -256,7 +202,7 @@ struct device_non_trivial_runs_benchmark_generator static constexpr unsigned int items_per_thread = 1u << ItemsPerThreadExp; }; - static void create(std::vector>& storage) + static void create(std::vector>& storage) { static_for_each< make_index_range, diff --git a/benchmark/benchmark_device_scan.cpp b/benchmark/benchmark_device_scan.cpp index 4a9d2d2b1..a2972fa50 100644 --- a/benchmark/benchmark_device_scan.cpp +++ b/benchmark/benchmark_device_scan.cpp @@ -22,132 +22,14 @@ #include "benchmark_device_scan.parallel.hpp" #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" - -#ifndef BENCHMARK_CONFIG_TUNING - #include "../common/utils_custom_type.hpp" -#endif - -// Google Benchmark -#include - -// HIP API -#include - -#ifndef BENCHMARK_CONFIG_TUNING - #include - #include -#endif - -#include -#include -#include -#ifndef BENCHMARK_CONFIG_TUNING - #include -#endif - -#ifndef DEFAULT_BYTES -const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; -#endif - -#define CREATE_EXCL_INCL_BENCHMARK(EXCL, T, SCAN_OP) \ - { \ - const device_scan_benchmark instance; \ - REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ - } - -#define CREATE_BENCHMARK(T, SCAN_OP) \ - CREATE_EXCL_INCL_BENCHMARK(false, T, SCAN_OP) \ - CREATE_EXCL_INCL_BENCHMARK(true, T, SCAN_OP) int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); -#ifdef BENCHMARK_CONFIG_TUNING - // optionally run an evenly split subset of benchmarks, when making multiple program invocations - parser.set_optional("parallel_instance", - "parallel_instance", - 0, - "parallel instance index"); - parser.set_optional("parallel_instances", - "parallel_instances", - 1, - "total parallel instances"); -#endif - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); - - // HIP - hipStream_t stream = 0; // default + benchmark_utils::executor executor(argc, argv, 128 * benchmark_utils::MiB, 10, 5); - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks = {}; -#ifdef BENCHMARK_CONFIG_TUNING - const int parallel_instance = parser.get("parallel_instance"); - const int parallel_instances = parser.get("parallel_instances"); - config_autotune_register::register_benchmark_subset(benchmarks, - parallel_instance, - parallel_instances, - bytes, - seed, - stream); -#else - using custom_float2 = common::custom_type; - using custom_double2 = common::custom_type; - - CREATE_BENCHMARK(int, rocprim::plus) - CREATE_BENCHMARK(float, rocprim::plus) - CREATE_BENCHMARK(double, rocprim::plus) - CREATE_BENCHMARK(long long, rocprim::plus) - CREATE_BENCHMARK(float2, rocprim::plus) - CREATE_BENCHMARK(custom_float2, rocprim::plus) - CREATE_BENCHMARK(double2, rocprim::plus) - CREATE_BENCHMARK(custom_double2, rocprim::plus) - CREATE_BENCHMARK(int8_t, rocprim::plus) - CREATE_BENCHMARK(uint8_t, rocprim::plus) - CREATE_BENCHMARK(rocprim::half, rocprim::plus) - CREATE_BENCHMARK(rocprim::int128_t, rocprim::plus) - CREATE_BENCHMARK(rocprim::uint128_t, rocprim::plus) +#ifndef BENCHMARK_CONFIG_TUNING + add_benchmarks(executor); #endif - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } - - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_scan.parallel.cpp.in b/benchmark/benchmark_device_scan.parallel.cpp.in index 86bc54475..7f4da1b82 100644 --- a/benchmark/benchmark_device_scan.parallel.cpp.in +++ b/benchmark/benchmark_device_scan.parallel.cpp.in @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -30,6 +30,6 @@ namespace { -auto benchmarks = config_autotune_register::create_bulk( +auto unused = benchmark_utils::executor::queue_autotune( device_scan_benchmark_generator<@DataType@, rocprim::block_scan_algorithm::@Algo@>::create); } // namespace diff --git a/benchmark/benchmark_device_scan.parallel.hpp b/benchmark/benchmark_device_scan.parallel.hpp index 3dbc0b8b1..f34d56174 100644 --- a/benchmark/benchmark_device_scan.parallel.hpp +++ b/benchmark/benchmark_device_scan.parallel.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -25,6 +25,11 @@ #include "benchmark_utils.hpp" +#include "../common/utils_device_ptr.hpp" +#ifndef BENCHMARK_CONFIG_TUNING + #include "../common/utils_custom_type.hpp" +#endif + // Google Benchmark #include @@ -41,6 +46,9 @@ #include #include #include +#else + #include + #include #endif #include @@ -49,6 +57,8 @@ #include #ifdef BENCHMARK_CONFIG_TUNING #include +#else + #include #endif template @@ -66,12 +76,12 @@ inline std::string config_name() return "default_config"; } -template, - bool Deterministic = false, - typename Config = rocprim::default_config> -struct device_scan_benchmark : public config_autotune_interface +template +struct device_scan_benchmark : public benchmark_utils::autotune_interface { std::string name() const override { @@ -93,7 +103,7 @@ struct device_scan_benchmark : public config_autotune_interface const bool debug = false) const -> typename std::enable_if::type { - if ROCPRIM_IF_CONSTEXPR(!Deterministic) + if constexpr(!Deterministic) { return rocprim::exclusive_scan(temporary_storage, storage_size, @@ -132,7 +142,7 @@ struct device_scan_benchmark : public config_autotune_interface typename std::enable_if::type { (void)initial_value; - if ROCPRIM_IF_CONSTEXPR(!Deterministic) + if constexpr(!Deterministic) { return rocprim::inclusive_scan(temporary_storage, storage_size, @@ -156,11 +166,12 @@ struct device_scan_benchmark : public config_autotune_interface } } - void run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + // Calculate the number of elements size_t size = bytes / sizeof(T); @@ -168,85 +179,39 @@ struct device_scan_benchmark : public config_autotune_interface const auto random_range = limit_random_range(0, 1000); std::vector input = get_random_data(size, random_range.first, random_range.second, seed.get_0()); - T initial_value = T(123); - T* d_input; - T* d_output; - HIP_CHECK(hipMalloc(&d_input, size * sizeof(T))); - HIP_CHECK(hipMalloc(&d_output, size * sizeof(T))); - HIP_CHECK(hipMemcpy(d_input, input.data(), size * sizeof(T), hipMemcpyHostToDevice)); + T initial_value = T(123); + common::device_ptr d_input(input); + common::device_ptr d_output(size); HIP_CHECK(hipDeviceSynchronize()); // Allocate temporary storage memory size_t temp_storage_size_bytes; - void* d_temp_storage = nullptr; // Get size of d_temp_storage - HIP_CHECK((run_device_scan(d_temp_storage, + HIP_CHECK((run_device_scan(nullptr, temp_storage_size_bytes, - d_input, - d_output, + d_input.get(), + d_output.get(), initial_value, size, scan_op, stream))); - HIP_CHECK(hipMalloc(&d_temp_storage, temp_storage_size_bytes)); - HIP_CHECK(hipDeviceSynchronize()); - - // Warm-up - for(size_t i = 0; i < 5; ++i) - { - HIP_CHECK((run_device_scan(d_temp_storage, - temp_storage_size_bytes, - d_input, - d_output, - initial_value, - size, - scan_op, - stream))); - } + common::device_ptr d_temp_storage(temp_storage_size_bytes); HIP_CHECK(hipDeviceSynchronize()); - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - const unsigned int batch_size = 10; - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) + state.run( + [&] { - HIP_CHECK((run_device_scan(d_temp_storage, + HIP_CHECK((run_device_scan(d_temp_storage.get(), temp_storage_size_bytes, - d_input, - d_output, + d_input.get(), + d_output.get(), initial_value, size, scan_op, stream))); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); + }); - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(T)); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_input)); - HIP_CHECK(hipFree(d_output)); - HIP_CHECK(hipFree(d_temp_storage)); + state.set_throughput(size, sizeof(T)); } }; @@ -264,7 +229,8 @@ struct device_scan_benchmark_generator template struct create_ipt { - void operator()(std::vector>& storage) + void operator()( + std::vector>& storage) { storage.emplace_back( std::make_unique>& storage) + void operator()( + std::vector>& storage) { // Limit items per thread to not over-use shared memory static constexpr unsigned int max_items_per_thread @@ -292,19 +259,50 @@ struct device_scan_benchmark_generator static constexpr unsigned int block_size = 1u << BlockSizeExponent; }; - static void create(std::vector>& storage) + static void + create(std::vector>& storage) { static_for_each(storage); } }; - static void create(std::vector>& storage) + static void create(std::vector>& storage) { // Block sizes 64, 128, 256 create_block_scan_algorithm>::create(storage); } }; +#else + + #define CREATE_EXCL_INCL_BENCHMARK(EXCL, T, SCAN_OP) \ + executor.queue_instance(device_scan_benchmark()); + + #define CREATE_BENCHMARK(T, SCAN_OP) \ + CREATE_EXCL_INCL_BENCHMARK(false, T, SCAN_OP) \ + CREATE_EXCL_INCL_BENCHMARK(true, T, SCAN_OP) + +template +void add_benchmarks(benchmark_utils::executor& executor) +{ + using custom_float2 = common::custom_type; + using custom_double2 = common::custom_type; + + CREATE_BENCHMARK(int, rocprim::plus) + CREATE_BENCHMARK(float, rocprim::plus) + CREATE_BENCHMARK(double, rocprim::plus) + CREATE_BENCHMARK(long long, rocprim::plus) + CREATE_BENCHMARK(float2, rocprim::plus) + CREATE_BENCHMARK(custom_float2, rocprim::plus) + CREATE_BENCHMARK(double2, rocprim::plus) + CREATE_BENCHMARK(custom_double2, rocprim::plus) + CREATE_BENCHMARK(int8_t, rocprim::plus) + CREATE_BENCHMARK(uint8_t, rocprim::plus) + CREATE_BENCHMARK(rocprim::half, rocprim::plus) + CREATE_BENCHMARK(rocprim::int128_t, rocprim::plus) + CREATE_BENCHMARK(rocprim::uint128_t, rocprim::plus) +} + #endif // BENCHMARK_CONFIG_TUNING #endif // ROCPRIM_BENCHMARK_DEVICE_SCAN_PARALLEL_HPP_ diff --git a/benchmark/benchmark_device_scan_by_key.cpp b/benchmark/benchmark_device_scan_by_key.cpp index 6b1931f90..87df5b436 100644 --- a/benchmark/benchmark_device_scan_by_key.cpp +++ b/benchmark/benchmark_device_scan_by_key.cpp @@ -22,145 +22,14 @@ #include "benchmark_device_scan_by_key.parallel.hpp" #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" - -#ifndef BENCHMARK_CONFIG_TUNING - #include "../common/utils_custom_type.hpp" -#endif - -// Google Benchmark -#include - -// HIP API -#include - -#ifndef BENCHMARK_CONFIG_TUNING - #include - #include -#endif - -#include -#include -#include -#ifndef BENCHMARK_CONFIG_TUNING - #include -#endif - -#ifndef DEFAULT_BYTES -const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; -#endif - -#define CREATE_BY_KEY_BENCHMARK(EXCL, T, SCAN_OP, MAX_SEGMENT_LENGTH) \ - { \ - const device_scan_by_key_benchmark, \ - MAX_SEGMENT_LENGTH> \ - instance; \ - REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ - } - -#define CREATE_EXCL_INCL_BENCHMARK(EXCL, T, SCAN_OP) \ - CREATE_BY_KEY_BENCHMARK(EXCL, T, SCAN_OP, 1) \ - CREATE_BY_KEY_BENCHMARK(EXCL, T, SCAN_OP, 16) \ - CREATE_BY_KEY_BENCHMARK(EXCL, T, SCAN_OP, 256) \ - CREATE_BY_KEY_BENCHMARK(EXCL, T, SCAN_OP, 4096) \ - CREATE_BY_KEY_BENCHMARK(EXCL, T, SCAN_OP, 65536) - -#define CREATE_BENCHMARK(T, SCAN_OP) \ - CREATE_EXCL_INCL_BENCHMARK(false, T, SCAN_OP) \ - CREATE_EXCL_INCL_BENCHMARK(true, T, SCAN_OP) int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); -#ifdef BENCHMARK_CONFIG_TUNING - // optionally run an evenly split subset of benchmarks, when making multiple program invocations - parser.set_optional("parallel_instance", - "parallel_instance", - 0, - "parallel instance index"); - parser.set_optional("parallel_instances", - "parallel_instances", - 1, - "total parallel instances"); -#endif - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); - - // HIP - hipStream_t stream = 0; // default + benchmark_utils::executor executor(argc, argv, 128 * benchmark_utils::MiB, 10, 5); - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks = {}; -#ifdef BENCHMARK_CONFIG_TUNING - const int parallel_instance = parser.get("parallel_instance"); - const int parallel_instances = parser.get("parallel_instances"); - config_autotune_register::register_benchmark_subset(benchmarks, - parallel_instance, - parallel_instances, - bytes, - seed, - stream); -#else - using custom_float2 = common::custom_type; - using custom_double2 = common::custom_type; - - CREATE_BENCHMARK(int, rocprim::plus) - CREATE_BENCHMARK(float, rocprim::plus) - CREATE_BENCHMARK(double, rocprim::plus) - CREATE_BENCHMARK(long long, rocprim::plus) - CREATE_BENCHMARK(float2, rocprim::plus) - CREATE_BENCHMARK(custom_float2, rocprim::plus) - CREATE_BENCHMARK(double2, rocprim::plus) - CREATE_BENCHMARK(custom_double2, rocprim::plus) - CREATE_BENCHMARK(int8_t, rocprim::plus) - CREATE_BENCHMARK(uint8_t, rocprim::plus) - CREATE_BENCHMARK(rocprim::half, rocprim::plus) - CREATE_BENCHMARK(rocprim::int128_t, rocprim::plus) - CREATE_BENCHMARK(rocprim::uint128_t, rocprim::plus) +#ifndef BENCHMARK_CONFIG_TUNING + add_benchmarks(executor); #endif - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } - - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_scan_by_key.parallel.cpp.in b/benchmark/benchmark_device_scan_by_key.parallel.cpp.in index 2e7b2b641..1fb1ef40c 100644 --- a/benchmark/benchmark_device_scan_by_key.parallel.cpp.in +++ b/benchmark/benchmark_device_scan_by_key.parallel.cpp.in @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -30,6 +30,6 @@ namespace { -auto benchmarks = config_autotune_register::create_bulk( +auto unused = benchmark_utils::executor::queue_autotune( device_scan_by_key_benchmark_generator<@KeyType@, @ValueType@, rocprim::block_scan_algorithm::@Algo@>::create); } // namespace diff --git a/benchmark/benchmark_device_scan_by_key.parallel.hpp b/benchmark/benchmark_device_scan_by_key.parallel.hpp index 089ff4cd3..18d2606ce 100644 --- a/benchmark/benchmark_device_scan_by_key.parallel.hpp +++ b/benchmark/benchmark_device_scan_by_key.parallel.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -25,6 +25,11 @@ #include "benchmark_utils.hpp" +#include "../common/utils_device_ptr.hpp" +#ifndef BENCHMARK_CONFIG_TUNING + #include "../common/utils_custom_type.hpp" +#endif + // Google Benchmark #include @@ -41,6 +46,9 @@ #include #include #include +#else + #include + #include #endif #include @@ -49,6 +57,8 @@ #include #ifdef BENCHMARK_CONFIG_TUNING #include +#else + #include #endif template @@ -66,15 +76,15 @@ inline std::string config_name() return "default_config"; } -template, - typename CompareOp = rocprim::equal_to, - unsigned int MaxSegmentLength = 1024, - bool Deterministic = false, - typename Config = rocprim::default_config> -struct device_scan_by_key_benchmark : public config_autotune_interface +template +struct device_scan_by_key_benchmark : public benchmark_utils::autotune_interface { std::string name() const override { @@ -100,7 +110,7 @@ struct device_scan_by_key_benchmark : public config_autotune_interface const bool debug = false) const -> typename std::enable_if::type { - if ROCPRIM_IF_CONSTEXPR(!Deterministic) + if constexpr(!Deterministic) { return rocprim::exclusive_scan_by_key(temporary_storage, storage_size, @@ -144,7 +154,7 @@ struct device_scan_by_key_benchmark : public config_autotune_interface const bool debug = false) const -> typename std::enable_if::type { - if ROCPRIM_IF_CONSTEXPR(!Deterministic) + if constexpr(!Deterministic) { return rocprim::inclusive_scan_by_key(temporary_storage, storage_size, @@ -172,11 +182,12 @@ struct device_scan_by_key_benchmark : public config_autotune_interface } } - void run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + // Calculate the number of elements size_t size = bytes / sizeof(Value); @@ -193,101 +204,44 @@ struct device_scan_by_key_benchmark : public config_autotune_interface ScanOp scan_op{}; CompareOp compare_op{}; - Value initial_value = Value(123); - Value* d_input; - Key* d_keys; - Value* d_output; - HIP_CHECK(hipMalloc(&d_input, input.size() * sizeof(input[0]))); - HIP_CHECK(hipMalloc(&d_keys, keys.size() * sizeof(keys[0]))); - HIP_CHECK(hipMalloc(&d_output, input.size() * sizeof(input[0]))); - HIP_CHECK(hipMemcpy(d_input, - input.data(), - input.size() * sizeof(input[0]), - hipMemcpyHostToDevice)); - HIP_CHECK( - hipMemcpy(d_keys, keys.data(), keys.size() * sizeof(keys[0]), hipMemcpyHostToDevice)); + Value initial_value = Value(123); + common::device_ptr d_input(input); + common::device_ptr d_keys(keys); + common::device_ptr d_output(input.size()); // Allocate temporary storage memory size_t temp_storage_size_bytes; - void* d_temp_storage = nullptr; // Get size of d_temp_storage - HIP_CHECK((run_device_scan_by_key(d_temp_storage, + HIP_CHECK((run_device_scan_by_key(nullptr, temp_storage_size_bytes, - d_keys, - d_input, - d_output, + d_keys.get(), + d_input.get(), + d_output.get(), initial_value, size, scan_op, compare_op, stream, debug))); - HIP_CHECK(hipMalloc(&d_temp_storage, temp_storage_size_bytes)); - - // Warm-up - for(size_t i = 0; i < 5; ++i) - { - HIP_CHECK((run_device_scan_by_key(d_temp_storage, - temp_storage_size_bytes, - d_keys, - d_input, - d_output, - initial_value, - size, - scan_op, - compare_op, - stream, - debug))); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - const unsigned int batch_size = 10; - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); + common::device_ptr d_temp_storage(temp_storage_size_bytes); - for(size_t i = 0; i < batch_size; ++i) + state.run( + [&] { - HIP_CHECK((run_device_scan_by_key(d_temp_storage, + HIP_CHECK((run_device_scan_by_key(d_temp_storage.get(), temp_storage_size_bytes, - d_keys, - d_input, - d_output, + d_keys.get(), + d_input.get(), + d_output.get(), initial_value, size, scan_op, compare_op, stream, debug))); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); + }); - state.SetBytesProcessed(state.iterations() * batch_size * size - * (sizeof(Key) + sizeof(Value))); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_input)); - HIP_CHECK(hipFree(d_keys)); - HIP_CHECK(hipFree(d_output)); - HIP_CHECK(hipFree(d_temp_storage)); + state.set_throughput(size, sizeof(Key) + sizeof(Value)); } }; @@ -305,7 +259,8 @@ struct device_scan_by_key_benchmark_generator template struct create_ipt { - void operator()(std::vector>& storage) + void operator()( + std::vector>& storage) { storage.emplace_back(std::make_unique>& storage) + void operator()( + std::vector>& storage) { // Limit items per thread to not over-use shared memory static constexpr unsigned int max_items_per_thread = ::rocprim::min( @@ -340,19 +296,63 @@ struct device_scan_by_key_benchmark_generator static constexpr unsigned int block_size = 1u << BlockSizeExponent; }; - static void create(std::vector>& storage) + static void + create(std::vector>& storage) { static_for_each(storage); } }; - static void create(std::vector>& storage) + static void create(std::vector>& storage) { // Block sizes 64, 128, 256 create_block_scan_algorithm>::create(storage); } }; +#else // BENCHMARK_CONFIG_TUNING + + #define CREATE_BY_KEY_BENCHMARK(EXCL, T, SCAN_OP, MAX_SEGMENT_LENGTH) \ + executor.queue_instance(device_scan_by_key_benchmark, \ + MAX_SEGMENT_LENGTH, \ + Deterministic>()); + + #define CREATE_EXCL_INCL_BENCHMARK(EXCL, T, SCAN_OP) \ + CREATE_BY_KEY_BENCHMARK(EXCL, T, SCAN_OP, 1) \ + CREATE_BY_KEY_BENCHMARK(EXCL, T, SCAN_OP, 16) \ + CREATE_BY_KEY_BENCHMARK(EXCL, T, SCAN_OP, 256) \ + CREATE_BY_KEY_BENCHMARK(EXCL, T, SCAN_OP, 4096) \ + CREATE_BY_KEY_BENCHMARK(EXCL, T, SCAN_OP, 65536) + + #define CREATE_BENCHMARK(T, SCAN_OP) \ + CREATE_EXCL_INCL_BENCHMARK(false, T, SCAN_OP) \ + CREATE_EXCL_INCL_BENCHMARK(true, T, SCAN_OP) + +template +void add_benchmarks(benchmark_utils::executor& executor) +{ + using custom_float2 = common::custom_type; + using custom_double2 = common::custom_type; + + CREATE_BENCHMARK(int, rocprim::plus) + CREATE_BENCHMARK(float, rocprim::plus) + CREATE_BENCHMARK(double, rocprim::plus) + CREATE_BENCHMARK(long long, rocprim::plus) + CREATE_BENCHMARK(float2, rocprim::plus) + CREATE_BENCHMARK(custom_float2, rocprim::plus) + CREATE_BENCHMARK(double2, rocprim::plus) + CREATE_BENCHMARK(custom_double2, rocprim::plus) + CREATE_BENCHMARK(int8_t, rocprim::plus) + CREATE_BENCHMARK(uint8_t, rocprim::plus) + CREATE_BENCHMARK(rocprim::half, rocprim::plus) + CREATE_BENCHMARK(rocprim::int128_t, rocprim::plus) + CREATE_BENCHMARK(rocprim::uint128_t, rocprim::plus) +} + #endif // BENCHMARK_CONFIG_TUNING #endif // ROCPRIM_BENCHMARK_DEVICE_SCAN_BY_KEY_PARALLEL_HPP_ diff --git a/benchmark/benchmark_device_scan_by_key_deterministic.cpp b/benchmark/benchmark_device_scan_by_key_deterministic.cpp index 33c59a33f..c4fb5b1b3 100644 --- a/benchmark/benchmark_device_scan_by_key_deterministic.cpp +++ b/benchmark/benchmark_device_scan_by_key_deterministic.cpp @@ -22,118 +22,14 @@ #include "benchmark_device_scan_by_key.parallel.hpp" #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" - -#include "../common/utils_custom_type.hpp" - -// Google Benchmark -#include - -// HIP API -#include - -#include -#include - -#include -#include -#include -#include - -#ifndef DEFAULT_BYTES -const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; -#endif - -#define CREATE_BY_KEY_BENCHMARK(EXCL, T, SCAN_OP, MAX_SEGMENT_LENGTH) \ - { \ - const device_scan_by_key_benchmark, \ - MAX_SEGMENT_LENGTH, \ - true> \ - instance; \ - REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ - } - -#define CREATE_EXCL_INCL_BENCHMARK(EXCL, T, SCAN_OP) \ - CREATE_BY_KEY_BENCHMARK(EXCL, T, SCAN_OP, 1) \ - CREATE_BY_KEY_BENCHMARK(EXCL, T, SCAN_OP, 16) \ - CREATE_BY_KEY_BENCHMARK(EXCL, T, SCAN_OP, 256) \ - CREATE_BY_KEY_BENCHMARK(EXCL, T, SCAN_OP, 4096) \ - CREATE_BY_KEY_BENCHMARK(EXCL, T, SCAN_OP, 65536) - -#define CREATE_BENCHMARK(T, SCAN_OP) \ - CREATE_EXCL_INCL_BENCHMARK(false, T, SCAN_OP) \ - CREATE_EXCL_INCL_BENCHMARK(true, T, SCAN_OP) int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); - - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks = {}; - using custom_float2 = common::custom_type; - using custom_double2 = common::custom_type; + benchmark_utils::executor executor(argc, argv, 128 * benchmark_utils::MiB, 10, 5); - CREATE_BENCHMARK(int, rocprim::plus) - CREATE_BENCHMARK(float, rocprim::plus) - CREATE_BENCHMARK(double, rocprim::plus) - CREATE_BENCHMARK(long long, rocprim::plus) - CREATE_BENCHMARK(float2, rocprim::plus) - CREATE_BENCHMARK(custom_float2, rocprim::plus) - CREATE_BENCHMARK(double2, rocprim::plus) - CREATE_BENCHMARK(custom_double2, rocprim::plus) - CREATE_BENCHMARK(int8_t, rocprim::plus) - CREATE_BENCHMARK(uint8_t, rocprim::plus) - CREATE_BENCHMARK(rocprim::half, rocprim::plus) - CREATE_BENCHMARK(rocprim::int128_t, rocprim::plus) - CREATE_BENCHMARK(rocprim::uint128_t, rocprim::plus) - - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } - - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); +#ifndef BENCHMARK_CONFIG_TUNING + add_benchmarks(executor); +#endif - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_scan_deterministic.cpp b/benchmark/benchmark_device_scan_deterministic.cpp index c25424745..82b77757d 100644 --- a/benchmark/benchmark_device_scan_deterministic.cpp +++ b/benchmark/benchmark_device_scan_deterministic.cpp @@ -22,104 +22,12 @@ #include "benchmark_device_scan.parallel.hpp" #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" - -#include "../common/utils_custom_type.hpp" - -// Google Benchmark -#include - -// HIP API -#include - -#include -#include - -#include -#include -#include -#include - -#ifndef DEFAULT_BYTES -const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; -#endif - -#define CREATE_EXCL_INCL_BENCHMARK(EXCL, T, SCAN_OP) \ - { \ - const device_scan_benchmark instance; \ - REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ - } - -#define CREATE_BENCHMARK(T, SCAN_OP) \ - CREATE_EXCL_INCL_BENCHMARK(false, T, SCAN_OP) \ - CREATE_EXCL_INCL_BENCHMARK(true, T, SCAN_OP) int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); - - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks = {}; - using custom_float2 = common::custom_type; - using custom_double2 = common::custom_type; - - CREATE_BENCHMARK(int, rocprim::plus) - CREATE_BENCHMARK(float, rocprim::plus) - CREATE_BENCHMARK(double, rocprim::plus) - CREATE_BENCHMARK(long long, rocprim::plus) - CREATE_BENCHMARK(float2, rocprim::plus) - CREATE_BENCHMARK(custom_float2, rocprim::plus) - CREATE_BENCHMARK(double2, rocprim::plus) - CREATE_BENCHMARK(custom_double2, rocprim::plus) - CREATE_BENCHMARK(int8_t, rocprim::plus) - CREATE_BENCHMARK(uint8_t, rocprim::plus) - CREATE_BENCHMARK(rocprim::half, rocprim::plus) - CREATE_BENCHMARK(rocprim::int128_t, rocprim::plus) - CREATE_BENCHMARK(rocprim::uint128_t, rocprim::plus) - - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } + benchmark_utils::executor executor(argc, argv, 128 * benchmark_utils::MiB, 10, 5); - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); + add_benchmarks(executor); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_search.cpp b/benchmark/benchmark_device_search.cpp index 1abbee457..b44bb4901 100644 --- a/benchmark/benchmark_device_search.cpp +++ b/benchmark/benchmark_device_search.cpp @@ -23,33 +23,14 @@ #include "benchmark_device_search.hpp" #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" - #include "../common/utils_custom_type.hpp" -// Google Benchmark -#include - -// HIP API -#include - #include -#include #include -#include -#include -#ifndef DEFAULT_BYTES -const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; -#endif - -#define CREATE_BENCHMARK_SEARCH(TYPE, KEY_SIZE, REPEATING) \ - { \ - const device_search_benchmark instance(KEY_SIZE, REPEATING); \ - REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ - } +#define CREATE_BENCHMARK_SEARCH(TYPE, KEY_SIZE, REPEATING) \ + executor.queue_instance(device_search_benchmark(KEY_SIZE, REPEATING)); #define CREATE_BENCHMARK_PATTERN(TYPE, REPEATING) \ { \ @@ -66,34 +47,8 @@ const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("bytes", "bytes", DEFAULT_BYTES, "number of values"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("bytes"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); + benchmark_utils::executor executor(argc, argv, 128 * benchmark_utils::MiB, 10, 5); - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks{}; CREATE_BENCHMARK(int) CREATE_BENCHMARK(long long) CREATE_BENCHMARK(int8_t) @@ -116,23 +71,5 @@ int main(int argc, char* argv[]) CREATE_BENCHMARK(custom_char_double) CREATE_BENCHMARK(custom_longlong_double) - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } - - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_search.hpp b/benchmark/benchmark_device_search.hpp index db6db6d1e..8312436df 100644 --- a/benchmark/benchmark_device_search.hpp +++ b/benchmark/benchmark_device_search.hpp @@ -26,6 +26,7 @@ #include "benchmark_utils.hpp" #include "../common/utils_data_generation.hpp" +#include "../common/utils_device_ptr.hpp" // Google Benchmark #include @@ -44,7 +45,7 @@ #include template -struct device_search_benchmark : public config_autotune_interface +struct device_search_benchmark : public benchmark_utils::autotune_interface { size_t key_size_ = 10; bool repeating_ = false; @@ -64,14 +65,12 @@ struct device_search_benchmark : public config_autotune_interface + ",value_type:" + std::string(Traits::name()) + ",cfg:default_config}"); } - static constexpr unsigned int batch_size = 10; - static constexpr unsigned int warmup_size = 5; - - void run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + using key_type = Key; using output_type = size_t; @@ -105,98 +104,43 @@ struct device_search_benchmark : public config_autotune_interface seed.get_0() + 1); } - key_type* d_keys_input; - key_type* d_input; - output_type* d_output; - HIP_CHECK(hipMalloc(&d_keys_input, key_size * sizeof(*d_keys_input))); - HIP_CHECK(hipMalloc(&d_input, size * sizeof(*d_input))); - HIP_CHECK(hipMalloc(&d_output, sizeof(*d_output))); - - HIP_CHECK(hipMemcpy(d_input, input.data(), size * sizeof(*d_input), hipMemcpyHostToDevice)); - - HIP_CHECK(hipMemcpy(d_keys_input, - keys_input.data(), - key_size * sizeof(*d_keys_input), - hipMemcpyHostToDevice)); + common::device_ptr d_keys_input(keys_input); + common::device_ptr d_input(input); + common::device_ptr d_output(1); rocprim::equal_to compare_op; - void* d_temporary_storage = nullptr; size_t temporary_storage_bytes = 0; - HIP_CHECK(rocprim::search(d_temporary_storage, + HIP_CHECK(rocprim::search(nullptr, temporary_storage_bytes, - d_input, - d_keys_input, - d_output, + d_input.get(), + d_keys_input.get(), + d_output.get(), size, key_size, compare_op, stream, false)); - HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); - - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - HIP_CHECK(rocprim::search(d_temporary_storage, - temporary_storage_bytes, - d_input, - d_keys_input, - d_output, - size, - key_size, - compare_op, - stream, - false)); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); + common::device_ptr d_temporary_storage(temporary_storage_bytes); - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) + state.run( + [&] { - HIP_CHECK(rocprim::search(d_temporary_storage, + HIP_CHECK(rocprim::search(d_temporary_storage.get(), temporary_storage_bytes, - d_input, - d_keys_input, - d_output, + d_input.get(), + d_keys_input.get(), + d_output.get(), size, key_size, compare_op, stream, false)); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(*d_input)); - state.SetItemsProcessed(state.iterations() * batch_size * size); + }); - HIP_CHECK(hipFree(d_temporary_storage)); - HIP_CHECK(hipFree(d_keys_input)); - HIP_CHECK(hipFree(d_input)); - HIP_CHECK(hipFree(d_output)); + state.set_throughput(size, sizeof(key_type)); } }; diff --git a/benchmark/benchmark_device_search_n.cpp b/benchmark/benchmark_device_search_n.cpp index ceaddfc20..f846904a6 100644 --- a/benchmark/benchmark_device_search_n.cpp +++ b/benchmark/benchmark_device_search_n.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -22,11 +22,6 @@ #include "benchmark_device_search_n.parallel.hpp" #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" - -// Google Benchmark -#include // HIP API #include @@ -35,56 +30,38 @@ #include #include -int main(int argc, char* argv[]) -{ - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", size_t{2} << 30, "number of input bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); +#define CREATE_BENCHMARK(T, S, C) executor.queue_instance(benchmark_search_n()); - // HIP - hipStream_t stream = 0; // default +#define CREATE_BENCHMARKS(T) \ + CREATE_BENCHMARK(T, size_t, count_equal_to<1>) \ + CREATE_BENCHMARK(T, size_t, count_equal_to<6>) \ + CREATE_BENCHMARK(T, size_t, count_equal_to<10>) \ + CREATE_BENCHMARK(T, size_t, count_equal_to<14>) \ + CREATE_BENCHMARK(T, size_t, count_equal_to<25>) \ + CREATE_BENCHMARK(T, size_t, count_is_percent_of_size<50>) \ + CREATE_BENCHMARK(T, size_t, count_is_percent_of_size<100>) - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks{}; - add_benchmark_search_n(benchmarks, seed, stream, size); +int main(int argc, char* argv[]) +{ + benchmark_utils::executor executor(argc, argv, 2 * benchmark_utils::GiB, 10, 10); - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } +#ifndef BENCHMARK_CONFIG_TUNING + using custom_int2 = common::custom_type; + using custom_longlong_double = common::custom_type; - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } + CREATE_BENCHMARKS(custom_int2) + CREATE_BENCHMARKS(custom_longlong_double) + CREATE_BENCHMARKS(int8_t) + CREATE_BENCHMARKS(int16_t) + CREATE_BENCHMARKS(int32_t) + CREATE_BENCHMARKS(int64_t) + CREATE_BENCHMARKS(rocprim::int128_t) + CREATE_BENCHMARKS(rocprim::uint128_t) + CREATE_BENCHMARKS(rocprim::half) + CREATE_BENCHMARKS(float) + CREATE_BENCHMARKS(double) +#endif // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - clean_up_benchmarks_search_n(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_search_n.parallel.cpp.in b/benchmark/benchmark_device_search_n.parallel.cpp.in index a988f23df..a6fe19d1a 100644 --- a/benchmark/benchmark_device_search_n.parallel.cpp.in +++ b/benchmark/benchmark_device_search_n.parallel.cpp.in @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -28,8 +28,10 @@ #include namespace { - auto benchmarks = config_autotune_register::create_bulk( + auto unused = benchmark_utils::executor::queue_autotune( device_search_n_benchmark_generator< @InputType@, - @BlockSize@>::create); + @BlockSize@, + @ItemsPerThread@, + @Threshold@>::create); } diff --git a/benchmark/benchmark_device_search_n.parallel.hpp b/benchmark/benchmark_device_search_n.parallel.hpp index 7c2268151..17607aafe 100644 --- a/benchmark/benchmark_device_search_n.parallel.hpp +++ b/benchmark/benchmark_device_search_n.parallel.hpp @@ -27,6 +27,7 @@ #include "cmdparser.hpp" #include "../common/utils_custom_type.hpp" +#include "../common/utils_device_ptr.hpp" // gbench #include @@ -36,24 +37,23 @@ // rocPRIM #include +#include #include -#include #include -#include +#ifndef BENCHMARK_CONFIG_TUNING + #include +#endif -// C++ Standard Library -#include #include -#include -#include -#include #include -#include #include - -using custom_int2 = common::custom_type; -using custom_double2 = common::custom_type; -using custom_longlong_double = common::custom_type; +#ifdef BENCHMARK_CONFIG_TUNING + #include +#else + #include + #include + #include +#endif namespace { @@ -75,444 +75,149 @@ constexpr bool is_type_arr_end = true; template constexpr bool is_type_arr_end> = false; -template -inline unsigned int search_n_get_item_per_block() +template +std::string search_n_config_name() { - using input_type = InputType; - using config = Config; - using wrapped_config = rocprim::detail::wrapped_search_n_config; - - hipStream_t stream = 0; // default - rocprim::detail::target_arch target_arch; - HIP_CHECK(rocprim::detail::host_target_arch(stream, target_arch)); - const auto params = rocprim::detail::dispatch_target_arch(target_arch); - const unsigned int block_size = params.kernel_config.block_size; - const unsigned int items_per_thread = params.kernel_config.items_per_thread; - const unsigned int items_per_block = block_size * items_per_thread; - return items_per_block; + const rocprim::detail::search_n_config_params config = Config(); + return "{bs:" + std::to_string(config.kernel_config.block_size) + + ",ipt:" + std::to_string(config.kernel_config.items_per_thread) + + ",threshold:" + std::to_string(config.threshold) + "}"; } -enum class benchmark_search_n_mode -{ - NORMAL = 0, - NOISE = 1, -}; - -inline std::string to_string(benchmark_search_n_mode e) noexcept +#ifndef BENCHMARK_CONFIG_TUNING +template<> +std::string search_n_config_name() { - switch(e) - { - case benchmark_search_n_mode::NORMAL: return "NORMAL"; - case benchmark_search_n_mode::NOISE: return "NOISE"; - default: return "UNKNOWN"; - } + return "default_config"; } +#endif -} // namespace - -template -class benchmark_search_n +template +struct count_equal_to { -public: - const managed_seed seed; - const hipStream_t stream; - size_t size_byte; - size_t count_byte; - size_t start_pos_byte; - InputType value; - std::vector input; - -private: - size_t size; - size_t count; - size_t start_pos; - const size_t warmup_size = 10; - const size_t batch_size = 10; - size_t temp_storage_size = 0; - size_t noise_sequence = 0; - bool create_noise = false; - - hipEvent_t start; - hipEvent_t stop; - - void* d_temp_storage = nullptr; - InputType* d_input; - OutputType* d_output; - InputType* d_value; - - void create() noexcept + std::string name() const { - switch(mode) - { - case benchmark_search_n_mode::NORMAL: - { - input.resize(size); - if(start_pos + count < size) - { - std::fill(input.begin(), input.begin() + start_pos, 0); - std::fill(input.begin() + start_pos, - input.begin() + count + start_pos, - value); - std::fill(input.begin() + count + start_pos, input.end(), 0); - } - else - { - std::fill(input.begin(), input.end(), 0); - } - break; - } - case benchmark_search_n_mode::NOISE: - { - InputType h_noise{0}; - input = std::vector(size, value); - - if(create_noise) - { - size_t cur_tile = 0; - size_t last_tile = size / count - 1; - while(cur_tile != last_tile) - { - input[cur_tile * count + count - 1] = h_noise; - ++cur_tile; - } - } - break; - } - default: - { - break; - } - } - - HIP_CHECK(hipMallocAsync(&d_value, sizeof(InputType), stream)); - HIP_CHECK(hipMallocAsync(&d_input, sizeof(InputType) * input.size(), stream)); - HIP_CHECK(hipMallocAsync(&d_output, sizeof(OutputType), stream)); - HIP_CHECK( - hipMemcpyAsync(d_value, &value, sizeof(InputType), hipMemcpyHostToDevice, stream)); - HIP_CHECK(hipMemcpyAsync(d_input, - input.data(), - sizeof(InputType) * input.size(), - hipMemcpyHostToDevice, - stream)); - - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); + return "count_equal_to<" + std::to_string(Value) + ">"; } - - void release() noexcept + constexpr size_t resolve(size_t) const { - decltype(input) tmp; - input.swap(tmp); // clear input memspace - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - HIP_CHECK(hipFree(d_value)); - - HIP_CHECK(hipFree(d_input)); - HIP_CHECK(hipFree(d_output)); + return Value; } +}; - void launch_search_n() +template +struct count_is_percent_of_size +{ + std::string name() const { - HIP_CHECK(::rocprim::search_n(d_temp_storage, - temp_storage_size, - d_input, - d_output, - size, - count, - d_value, - rocprim::equal_to{}, - stream, - false)); + return "count_is_percent_of_size<" + std::to_string(Value) + ">"; } - - static void run(benchmark::State& state, benchmark_search_n const* _self) + constexpr size_t resolve(size_t size) const { - auto& self = *const_cast(_self); - self.create(); - - // allocate memory - self.launch_search_n(); - HIP_CHECK(hipMallocAsync(&self.d_temp_storage, self.temp_storage_size, self.stream)); - // Warm-up - for(size_t i = 0; i < self.warmup_size; ++i) - { - self.launch_search_n(); - } - HIP_CHECK(hipStreamSynchronize(self.stream)); - - // Run - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(self.start, self.stream)); - - for(size_t i = 0; i < self.batch_size; ++i) - { - self.launch_search_n(); - } + return size * Value / 100; + } +}; - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(self.stop, self.stream)); - HIP_CHECK(hipEventSynchronize(self.stop)); +} // namespace - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, self.start, self.stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - // Clean-up - HIP_CHECK(hipFree(self.d_temp_storage)); - self.d_temp_storage = nullptr; - self.temp_storage_size = 0; - state.SetBytesProcessed(state.iterations() * self.batch_size * self.size - * sizeof(*(self.d_input))); - state.SetItemsProcessed(state.iterations() * self.batch_size * self.size); - self.release(); - } +template +class benchmark_search_n : public benchmark_utils::autotune_interface +{ public: - benchmark_search_n( - const managed_seed _seed, - const hipStream_t _stream, - const size_t _size_byte, - const size_t _count_byte, // for NOISE benchmarks, this is the multiple of count - const size_t _start_pos_byte) noexcept - : seed(_seed) - , stream(_stream) - , size_byte(_size_byte) - , count_byte(_count_byte) - , start_pos_byte(_start_pos_byte) - , value{1} - , input() + void run(benchmark_utils::state&& state) override { - switch(mode) + const auto& stream = state.stream; + const auto& size_byte = state.bytes; + + InputType h_noise{0}; + InputType h_value{1}; + common::device_ptr d_temp_storage; + size_t temp_storage_size = 0; + size_t size; + size_t count; + std::vector input{}; + common::device_ptr d_input; + common::device_ptr d_output(1); + common::device_ptr d_value(std::vector{h_value}, stream); + + size = size_byte / sizeof(InputType); + + count = CountCalculator{}.resolve(size); + size_t cur_tile = 0; + size_t last_tile = size / count - 1; + input = std::vector(size, h_value); + while(cur_tile != last_tile) { - case benchmark_search_n_mode::NORMAL: - { - size = size_byte / sizeof(InputType); - count = count_byte / sizeof(InputType); - start_pos = start_pos_byte / sizeof(InputType); - break; - } - case benchmark_search_n_mode::NOISE: - { - size = size_byte / sizeof(InputType); - count = count_byte; - noise_sequence - = _start_pos_byte == (size_t)-1 - ? search_n_get_item_per_block() - : _start_pos_byte; - - if(size > noise_sequence * count) - { - count = noise_sequence * count; - create_noise = true; - } - break; - } + input[cur_tile * count + count - 1] = h_noise; + ++cur_tile; } + + d_input.store_async(input, stream); + + auto launch_search_n = [&]() + { + HIP_CHECK(::rocprim::search_n(d_temp_storage.get(), + temp_storage_size, + d_input.get(), + d_output.get(), + size, + count, + d_value.get(), + rocprim::equal_to{}, + stream, + false)); + }; + + // allocate temp memory + launch_search_n(); + d_temp_storage.resize_async(temp_storage_size, stream); + + state.run([&] { launch_search_n(); }); + + state.set_throughput(size, sizeof(InputType)); } - benchmark::internal::Benchmark* bench_register() const noexcept + std::string name() const override { - return benchmark::RegisterBenchmark( - bench_naming::format_name( - "{lvl:device,algo:search_n,input_type:" + std::string(Traits::name()) - + ",size:" + std::to_string(size) + ",count:" + std::to_string(count) - + ",mode:" + to_string(mode) + ",cfg:default_config}") - .c_str(), - run, - this); + return bench_naming::format_name("{lvl:device,algo:search_n,data_type:" + + std::string(Traits::name()) + + ",count_calculator:" + CountCalculator{}.name() + + ",cfg:" + search_n_config_name() + "}") + .c_str(); } }; -using destructor_t = std::function; -static std::vector destructors; +#ifdef BENCHMARK_CONFIG_TUNING -static void clean_up_benchmarks_search_n() +template +struct device_search_n_benchmark_generator { - for(auto& i : destructors) + static void create(std::vector>& storage) { - i(); + using config = rocprim::search_n_config; + storage.emplace_back( + std::make_unique, config>>()); + storage.emplace_back( + std::make_unique, config>>()); + storage.emplace_back( + std::make_unique, config>>()); + storage.emplace_back( + std::make_unique, config>>()); + storage.emplace_back( + std::make_unique, config>>()); + storage.emplace_back( + std::make_unique< + benchmark_search_n, config>>()); + storage.emplace_back( + std::make_unique< + benchmark_search_n, config>>()); } - destructors = {}; -} - -template -inline void add_one_benchmark_search_n(std::vector& benchmarks, - const managed_seed _seed, - const hipStream_t _stream, - const size_t _size_byte) -{ - // normal - auto half = new benchmark_search_n(_seed, - _stream, - _size_byte, - _size_byte / 2, - _size_byte / 2); - // small count test - auto small_count1 - = new benchmark_search_n(_seed, - _stream, - _size_byte, - 1, // count times - 1); - auto small_count2 - = new benchmark_search_n(_seed, - _stream, - _size_byte, - 1, // count times - 2); - auto small_count4 - = new benchmark_search_n(_seed, - _stream, - _size_byte, - 1, // count times - 4); - auto small_count6 - = new benchmark_search_n(_seed, - _stream, - _size_byte, - 1, // count times - 6); - auto small_count8 - = new benchmark_search_n(_seed, - _stream, - _size_byte, - 1, // count times - 8); - auto small_count10 - = new benchmark_search_n(_seed, - _stream, - _size_byte, - 1, // count times - 10); - auto small_count12 - = new benchmark_search_n(_seed, - _stream, - _size_byte, - 1, // count times - 12); - auto small_count36 - = new benchmark_search_n(_seed, - _stream, - _size_byte, - 1, // count times - 36); - // mid count test - auto mid_count1023 - = new benchmark_search_n(_seed, - _stream, - _size_byte, - 1, // count times - 1023); - auto mid_count2047 - = new benchmark_search_n(_seed, - _stream, - _size_byte, - 1, // count times - 2047); - - auto mid_count4095 - = new benchmark_search_n(_seed, - _stream, - _size_byte, - 1, // count times - 4095); - // big input - auto big_count3 = new benchmark_search_n( - _seed, - _stream, - _size_byte, - 3, // count times - (size_t)-1); // block_size - auto big_count6 = new benchmark_search_n( - _seed, - _stream, - _size_byte, - 6, // count times - (size_t)-1); // block_size - std::vector bs = { - - small_count1->bench_register(), - small_count2->bench_register(), - small_count4->bench_register(), - small_count6->bench_register(), - small_count8->bench_register(), - small_count10->bench_register(), - small_count12->bench_register(), - small_count36->bench_register(), - - mid_count1023->bench_register(), - mid_count2047->bench_register(), - mid_count4095->bench_register(), - - big_count3->bench_register(), - big_count6->bench_register(), - half->bench_register()}; - - destructors.emplace_back( - [=]() - { - delete small_count1; - delete small_count2; - delete small_count4; - delete small_count6; - delete small_count8; - delete small_count10; - delete small_count12; - delete small_count36; - - delete mid_count1023; - delete mid_count2047; - delete mid_count4095; - - delete big_count3; - delete big_count6; - - delete half; - }); - - benchmarks.insert(benchmarks.end(), bs.begin(), bs.end()); -} - -template, bool> = true> -inline void add_benchmark_search_n(std::vector& benchmarks, - const managed_seed _seed, - const hipStream_t _stream, - const size_t _size_byte) -{ - add_one_benchmark_search_n(benchmarks, _seed, _stream, _size_byte); - add_benchmark_search_n(benchmarks, _seed, _stream, _size_byte); -} -template, bool> = true> -inline void add_benchmark_search_n(std::vector& benchmarks, - const managed_seed _seed, - const hipStream_t _stream, - const size_t _size_byte) -{ - add_one_benchmark_search_n(benchmarks, _seed, _stream, _size_byte); -} - -using benchmark_search_n_types = type_arr; - -template -struct device_search_n_benchmark_generator -{ - // TODO: add implementation - struct create_search_n_algorithm - {}; - // TODO: add implementation - static void create(std::vector>&) {} }; +#endif // BENCHMARK_CONFIG_TUNING + #endif // ROCPRIM_BENCHMARK_DEVICE_SEARCH_N_PARALLEL_HPP_ diff --git a/benchmark/benchmark_device_segmented_radix_sort_keys.cpp b/benchmark/benchmark_device_segmented_radix_sort_keys.cpp index bc5ef857f..2fc1da750 100644 --- a/benchmark/benchmark_device_segmented_radix_sort_keys.cpp +++ b/benchmark/benchmark_device_segmented_radix_sort_keys.cpp @@ -20,15 +20,11 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. +#include "benchmark_device_segmented_radix_sort_keys.parallel.hpp" #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" #include "../common/utils_data_generation.hpp" -// Google Benchmark -#include - // HIP API #include @@ -46,304 +42,53 @@ #include #endif -#ifndef DEFAULT_BYTES -const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; -#endif - -namespace -{ - -constexpr unsigned int warmup_size = 2; -constexpr size_t min_size = 30000; -constexpr std::array segment_counts{10, 100, 1000, 2500, 5000, 7500, 10000, 100000}; -constexpr std::array segment_lengths{30, 256, 3000, 300000}; -} // namespace - // This benchmark only handles the rocprim::segmented_radix_sort_keys function. The benchmark was separated into two (keys and pairs), // because the binary became too large to link. Runs into a "relocation R_X86_64_PC32 out of range" error. // This happens partially, because of the algorithm has 4 kernels, and decides at runtime which one to call. -template -void run_sort_keys_benchmark(benchmark::State& state, - size_t num_segments, - size_t mean_segment_length, - size_t target_bytes, - const managed_seed& seed, - hipStream_t stream) -{ - using offset_type = int; - using key_type = Key; - - // Calculate the number of elements - size_t target_size = target_bytes / sizeof(key_type); - - std::vector offsets; - offsets.push_back(0); - - static constexpr int iseed = 716; - engine_type gen(iseed); - - std::normal_distribution segment_length_dis(static_cast(mean_segment_length), - 0.1 * mean_segment_length); - - size_t offset = 0; - for(size_t segment_index = 0; segment_index < num_segments;) - { - const double segment_length_candidate = std::round(segment_length_dis(gen)); - if(segment_length_candidate < 0) - { - continue; - } - const offset_type segment_length = static_cast(segment_length_candidate); - offset += segment_length; - offsets.push_back(offset); - ++segment_index; - } - const size_t size = offset; - const size_t segments_count = offsets.size() - 1; - - std::vector keys_input - = get_random_data(size, - common::generate_limits::min(), - common::generate_limits::max(), - seed.get_0()); - - size_t batch_size = 1; - if(size < target_size) - { - batch_size = (target_size + size - 1) / size; - } - - offset_type* d_offsets; - HIP_CHECK(hipMalloc(&d_offsets, offsets.size() * sizeof(offset_type))); - HIP_CHECK(hipMemcpy(d_offsets, - offsets.data(), - offsets.size() * sizeof(offset_type), - hipMemcpyHostToDevice)); - - key_type* d_keys_input; - key_type* d_keys_output; - HIP_CHECK(hipMalloc(&d_keys_input, size * sizeof(key_type))); - HIP_CHECK(hipMalloc(&d_keys_output, size * sizeof(key_type))); - HIP_CHECK( - hipMemcpy(d_keys_input, keys_input.data(), size * sizeof(key_type), hipMemcpyHostToDevice)); - - void* d_temporary_storage = nullptr; - size_t temporary_storage_bytes = 0; - HIP_CHECK(rocprim::segmented_radix_sort_keys(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - d_keys_output, - size, - segments_count, - d_offsets, - d_offsets + 1, - 0, - sizeof(key_type) * 8, - stream, - false)); - - HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); - HIP_CHECK(hipDeviceSynchronize()); - - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - HIP_CHECK(rocprim::segmented_radix_sort_keys(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - d_keys_output, - size, - segments_count, - d_offsets, - d_offsets + 1, - 0, - sizeof(key_type) * 8, - stream, - false)); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) - { - HIP_CHECK(rocprim::segmented_radix_sort_keys(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - d_keys_output, - size, - segments_count, - d_offsets, - d_offsets + 1, - 0, - sizeof(key_type) * 8, - stream, - false)); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(key_type)); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_temporary_storage)); - HIP_CHECK(hipFree(d_offsets)); - HIP_CHECK(hipFree(d_keys_input)); - HIP_CHECK(hipFree(d_keys_output)); -} - template -void add_sort_keys_benchmarks(std::vector& benchmarks, - size_t max_bytes, - size_t min_size, - size_t target_size, - const managed_seed& seed, - hipStream_t stream) +void add_benchmarks(benchmark_utils::executor& executor, size_t bytes) { - // Calculate the number of elements - size_t max_size = max_bytes / sizeof(KeyT); + constexpr std::array segment_counts{10, 100, 1000, 2500, 5000, 7500, 10000, 100000}; + constexpr std::array segment_lengths{30, 256, 3000, 300000}; + + constexpr size_t min_size = 30000; + size_t max_size = bytes / sizeof(KeyT); - std::string key_name = Traits::name(); - std::string value_name = Traits::name(); for(const auto segment_count : segment_counts) { for(const auto segment_length : segment_lengths) { + // This check is also present in device_segmented_radix_sort_keys_benchmark its run() + // We need it here to prevent Google Benchmark causing an infinite loop const auto number_of_elements = segment_count * segment_length; - if(number_of_elements > max_size || number_of_elements < min_size) + if(number_of_elements < min_size || number_of_elements > max_size) { continue; } - benchmarks.push_back(benchmark::RegisterBenchmark( - bench_naming::format_name( - "{lvl:device,algo:radix_sort_segmented,key_type:" + key_name + ",value_type:" - + value_name + ",segment_count:" + std::to_string(segment_count) - + ",segment_length:" + std::to_string(segment_length) + ",cfg:default_config}") - .c_str(), - [=](benchmark::State& state) - { - run_sort_keys_benchmark(state, - segment_count, - segment_length, - target_size, - seed, - stream); - })); + + executor.queue_instance( + device_segmented_radix_sort_keys_benchmark(segment_count, segment_length)); } } } int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); + size_t bytes = 128 * benchmark_utils::MiB; -#ifdef BENCHMARK_CONFIG_TUNING - // optionally run an evenly split subset of benchmarks, when making multiple program invocations - parser.set_optional("parallel_instance", - "parallel_instance", - 0, - "parallel instance index"); - parser.set_optional("parallel_instances", - "parallel_instances", - 1, - "total parallel instances"); -#endif - - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); - - // HIP - hipStream_t stream = 0; // default + benchmark_utils::executor executor(argc, argv, bytes, 10, 5); - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks; -#ifdef BENCHMARK_CONFIG_TUNING - (void)min_size; - const int parallel_instance = parser.get("parallel_instance"); - const int parallel_instances = parser.get("parallel_instances"); - config_autotune_register::register_benchmark_subset(benchmarks, - parallel_instance, - parallel_instances, - min_size, - seed, - stream); -#else - add_sort_keys_benchmarks(benchmarks, bytes, min_size, bytes / 2, seed, stream); - add_sort_keys_benchmarks(benchmarks, bytes, min_size, bytes / 2, seed, stream); - add_sort_keys_benchmarks(benchmarks, bytes, min_size, bytes / 2, seed, stream); - add_sort_keys_benchmarks(benchmarks, bytes, min_size, bytes / 2, seed, stream); - add_sort_keys_benchmarks(benchmarks, bytes, min_size, bytes / 2, seed, stream); - add_sort_keys_benchmarks(benchmarks, bytes, min_size, bytes / 2, seed, stream); - add_sort_keys_benchmarks(benchmarks, - bytes, - min_size, - bytes / 2, - seed, - stream); - add_sort_keys_benchmarks(benchmarks, - bytes, - min_size, - bytes / 2, - seed, - stream); +#ifndef BENCHMARK_CONFIG_TUNING + add_benchmarks(executor, bytes); + add_benchmarks(executor, bytes); + add_benchmarks(executor, bytes); + add_benchmarks(executor, bytes); + add_benchmarks(executor, bytes); + add_benchmarks(executor, bytes); + add_benchmarks(executor, bytes); + add_benchmarks(executor, bytes); #endif - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } - - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.cpp.in b/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.cpp.in index 6e2875ca1..4ff4147d8 100644 --- a/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.cpp.in +++ b/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.cpp.in @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -29,10 +29,9 @@ namespace { -auto benchmarks = config_autotune_register::create_bulk( - device_segmented_radix_sort_benchmark_generator< - @LongBits@, - 0, +auto unused = benchmark_utils::executor::queue_autotune( + device_segmented_radix_sort_keys_benchmark_generator< + @RadixBits@, @BlockSize@, @ItemsPerThread@, @WarpSmallLWS@, diff --git a/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.hpp b/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.hpp index d4f9d3fab..0b9260295 100644 --- a/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.hpp +++ b/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.hpp @@ -26,6 +26,7 @@ #include "benchmark_utils.hpp" #include "../common/utils_data_generation.hpp" +#include "../common/utils_device_ptr.hpp" // Google Benchmark #include @@ -45,7 +46,6 @@ #include #include #include -#include #include template @@ -67,8 +67,7 @@ std::string config_name() const rocprim::detail::segmented_radix_sort_config_params config = Config(); return "{bs:" + std::to_string(config.kernel_config.block_size) + ",ipt:" + std::to_string(config.kernel_config.items_per_thread) - + ",lrb:" + std::to_string(config.long_radix_bits) - + ",srb:" + std::to_string(config.short_radix_bits) + + ",rb:" + std::to_string(config.radix_bits) + ",eupws:" + std::to_string(config.enable_unpartitioned_warp_sort) + ",wsc:" + warp_sort_config_name(config.warp_sort_config) + "}"; } @@ -79,28 +78,49 @@ inline std::string config_name() return "default_config"; } -template -struct device_segmented_radix_sort_benchmark : public config_autotune_interface +template +struct device_segmented_radix_sort_keys_benchmark : public benchmark_utils::autotune_interface { +private: + std::vector segment_counts; + std::vector segment_lengths; + size_t total_size; + +public: + device_segmented_radix_sort_keys_benchmark(size_t segment_count, size_t segment_length) + { + segment_counts.push_back(segment_count); + segment_lengths.push_back(segment_length); + } + + device_segmented_radix_sort_keys_benchmark(const std::vector& segment_counts, + const std::vector& segment_lengths) + { + this->segment_counts = segment_counts; + this->segment_lengths = segment_lengths; + } + std::string name() const override { using namespace std::string_literals; - const rocprim::detail::segmented_radix_sort_config_params config = Config(); return bench_naming::format_name( "{lvl:device,algo:segmented_radix_sort,key_type:" + std::string(Traits::name()) - + ",value_type:empty_type" + ",cfg:" + config_name() + "}"); + + ",value_type:empty_type" + + (segment_counts.size() == 1 ? ",segment_count:"s + std::to_string(segment_counts[0]) + : ""s) + + (segment_lengths.size() == 1 + ? ",segment_length:"s + std::to_string(segment_lengths[0]) + : ""s) + + ",cfg:" + config_name() + "}"); } - static constexpr unsigned int batch_size = 10; - static constexpr unsigned int warmup_size = 5; - - void run_benchmark(benchmark::State& state, - size_t num_segments, - size_t mean_segment_length, - size_t target_size, - const managed_seed& seed, - hipStream_t stream) const + void run_benchmark(benchmark_utils::state&& state, + size_t num_segments, + size_t mean_segment_length) { + const auto& stream = state.stream; + const auto& seed = state.seed; + using offset_type = int; using key_type = Key; @@ -136,145 +156,87 @@ struct device_segmented_radix_sort_benchmark : public config_autotune_interface common::generate_limits::max(), seed.get_0()); - size_t batch_size = 1; - if(size < target_size) - { - batch_size = (target_size + size - 1) / size; - } + common::device_ptr d_offsets(offsets); + + common::device_ptr d_keys_input(keys_input); + common::device_ptr d_keys_output(size); - offset_type* d_offsets; - HIP_CHECK(hipMalloc(&d_offsets, offsets.size() * sizeof(offset_type))); - HIP_CHECK(hipMemcpy(d_offsets, - offsets.data(), - offsets.size() * sizeof(offset_type), - hipMemcpyHostToDevice)); - - key_type* d_keys_input; - key_type* d_keys_output; - HIP_CHECK(hipMalloc(&d_keys_input, size * sizeof(key_type))); - HIP_CHECK(hipMalloc(&d_keys_output, size * sizeof(key_type))); - HIP_CHECK(hipMemcpy(d_keys_input, - keys_input.data(), - size * sizeof(key_type), - hipMemcpyHostToDevice)); - - void* d_temporary_storage = nullptr; size_t temporary_storage_bytes = 0; - HIP_CHECK(rocprim::segmented_radix_sort_keys(d_temporary_storage, + HIP_CHECK(rocprim::segmented_radix_sort_keys(nullptr, temporary_storage_bytes, - d_keys_input, - d_keys_output, + d_keys_input.get(), + d_keys_output.get(), size, segments_count, - d_offsets, - d_offsets + 1, + d_offsets.get(), + d_offsets.get() + 1, 0, sizeof(key_type) * 8, stream, false)); - HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); - HIP_CHECK(hipDeviceSynchronize()); - - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - HIP_CHECK(rocprim::segmented_radix_sort_keys(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - d_keys_output, - size, - segments_count, - d_offsets, - d_offsets + 1, - 0, - sizeof(key_type) * 8, - stream, - false)); - } + common::device_ptr d_temporary_storage(temporary_storage_bytes); HIP_CHECK(hipDeviceSynchronize()); - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) + state.run( + [&] { - HIP_CHECK(rocprim::segmented_radix_sort_keys(d_temporary_storage, + HIP_CHECK(rocprim::segmented_radix_sort_keys(d_temporary_storage.get(), temporary_storage_bytes, - d_keys_input, - d_keys_output, + d_keys_input.get(), + d_keys_output.get(), size, segments_count, - d_offsets, - d_offsets + 1, + d_offsets.get(), + d_offsets.get() + 1, 0, sizeof(key_type) * 8, stream, false)); - } + }); - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(key_type)); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_temporary_storage)); - HIP_CHECK(hipFree(d_offsets)); - HIP_CHECK(hipFree(d_keys_input)); - HIP_CHECK(hipFree(d_keys_output)); + total_size += size; } - void run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { - // Calculate the number of elements - size_t size = bytes / sizeof(Key); + total_size = 0; - constexpr std::array - segment_counts{10, 100, 1000, 2500, 5000, 7500, 10000, 100000}; - constexpr std::array segment_lengths{30, 256, 3000, 300000}; - - for(const auto segment_count : segment_counts) + if(segment_counts.size() == 1) + { + run_benchmark(std::forward(state), + segment_counts[0], + segment_lengths[0]); + } + else { - for(const auto segment_length : segment_lengths) + state.accumulate_total_gbench_iterations_every_run(); + + constexpr size_t min_size = 300000; + constexpr size_t max_size = 33554432; + + for(const auto segment_count : segment_counts) { - const auto number_of_elements = segment_count * segment_length; - if(number_of_elements > 33554432 || number_of_elements < 300000) + for(const auto segment_length : segment_lengths) { - continue; + const auto number_of_elements = segment_count * segment_length; + if(number_of_elements < min_size || number_of_elements > max_size) + { + continue; + } + + run_benchmark(std::forward(state), + segment_count, + segment_length); } - - run_benchmark(state, segment_count, segment_length, size, seed, stream); } } + + state.set_throughput(total_size, sizeof(Key)); } }; -template class T, bool enable, Tp... Idx> -struct decider; - -template -struct device_segmented_radix_sort_benchmark_generator +struct device_segmented_radix_sort_keys_benchmark_generator { template - static auto __create(std::vector>& storage) -> - typename std::enable_if<(key_size * BlockSize * ItemsPerThread < TUNING_SHARED_MEMORY_MAX), - void>::type + static auto _create(std::vector>& storage) + -> std::enable_if_t<(key_size * BlockSize * ItemsPerThread < TUNING_SHARED_MEMORY_MAX)> { - storage.emplace_back(std::make_unique segment_counts{10, 100, 1000, 2500, 5000, 7500, 10000, 100000}; + const std::vector segment_lengths{30, 256, 3000, 300000}; + + storage.emplace_back(std::make_unique, rocprim::WarpSortConfig, - UnpartitionWarpAllowed>>>()); + UnpartitionWarpAllowed>>>(segment_counts, segment_lengths)); } + template - static auto __create(std::vector>&) -> - typename std::enable_if::type + static auto _create(std::vector>&) + -> std::enable_if_t {} - static void create(std::vector>& storage) - { - __create(storage); - } -}; -template class T, Tp... Idx> -struct decider -{ - inline static void - do_the_thing(std::vector>& storage) + static void create(std::vector>& storage) { - static_for_each, T>(storage); + _create(storage); } }; -template class T, Tp... Idx> -struct decider -{ - inline static void - do_the_thing(std::vector>& /*storage*/) - {} -}; - #endif // ROCPRIM_BENCHMARK_DEVICE_SEGMENTED_RADIX_SORT_KEYS_PARALLEL_HPP_ diff --git a/benchmark/benchmark_device_segmented_radix_sort_pairs.cpp b/benchmark/benchmark_device_segmented_radix_sort_pairs.cpp index bbbb6169d..5ea0d165e 100644 --- a/benchmark/benchmark_device_segmented_radix_sort_pairs.cpp +++ b/benchmark/benchmark_device_segmented_radix_sort_pairs.cpp @@ -20,18 +20,14 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. +#include "benchmark_device_segmented_radix_sort_pairs.parallel.hpp" #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" #include "../common/utils_data_generation.hpp" #ifndef BENCHMARK_CONFIG_TUNING #include "../common/utils_custom_type.hpp" #endif -// Google Benchmark -#include "benchmark/benchmark.h" - // HIP API #include @@ -52,354 +48,58 @@ #include #endif -#ifndef DEFAULT_BYTES -const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; -#endif - -namespace -{ - -constexpr unsigned int warmup_size = 2; -constexpr size_t min_size = 30000; -constexpr std::array segment_counts{10, 100, 1000, 2500, 5000, 7500, 10000, 100000}; -constexpr std::array segment_lengths{30, 256, 3000, 300000}; -} // namespace - // This benchmark only handles the rocprim::segmented_radix_sort_pairs function. The benchmark was separated into two (keys and pairs), // because the binary became too large to link. Runs into a "relocation R_X86_64_PC32 out of range" error. // This happens partially, because of the algorithm has 4 kernels, and decides at runtime which one to call. -template -void run_sort_pairs_benchmark(benchmark::State& state, - size_t num_segments, - size_t mean_segment_length, - size_t target_bytes, - const managed_seed& seed, - hipStream_t stream) -{ - using offset_type = int; - using key_type = Key; - using value_type = Value; - - // Calculate the number of elements - size_t target_size = target_bytes / sizeof(key_type); - - // Generate data - std::vector offsets; - offsets.push_back(0); - - static constexpr int iseed = 716; - engine_type gen(iseed); - - std::normal_distribution segment_length_dis(static_cast(mean_segment_length), - 0.1 * mean_segment_length); - - size_t offset = 0; - for(size_t segment_index = 0; segment_index < num_segments;) - { - const double segment_length_candidate = std::round(segment_length_dis(gen)); - if(segment_length_candidate < 0) - { - continue; - } - const offset_type segment_length = static_cast(segment_length_candidate); - offset += segment_length; - offsets.push_back(offset); - ++segment_index; - } - const size_t size = offset; - const size_t segments_count = offsets.size() - 1; - - std::vector keys_input - = get_random_data(size, - common::generate_limits::min(), - common::generate_limits::max(), - seed.get_0()); - - size_t batch_size = 1; - if(size < target_size) - { - batch_size = (target_size + size - 1) / size; - } - - std::vector values_input(size); - std::iota(values_input.begin(), values_input.end(), 0); - - offset_type* d_offsets; - HIP_CHECK(hipMalloc(&d_offsets, (segments_count + 1) * sizeof(offset_type))); - HIP_CHECK(hipMemcpy(d_offsets, - offsets.data(), - (segments_count + 1) * sizeof(offset_type), - hipMemcpyHostToDevice)); - - key_type* d_keys_input; - key_type* d_keys_output; - HIP_CHECK(hipMalloc(&d_keys_input, size * sizeof(key_type))); - HIP_CHECK(hipMalloc(&d_keys_output, size * sizeof(key_type))); - HIP_CHECK( - hipMemcpy(d_keys_input, keys_input.data(), size * sizeof(key_type), hipMemcpyHostToDevice)); - - value_type* d_values_input; - value_type* d_values_output; - HIP_CHECK(hipMalloc(&d_values_input, size * sizeof(value_type))); - HIP_CHECK(hipMalloc(&d_values_output, size * sizeof(value_type))); - HIP_CHECK(hipMemcpy(d_values_input, - values_input.data(), - size * sizeof(value_type), - hipMemcpyHostToDevice)); - - void* d_temporary_storage = nullptr; - size_t temporary_storage_bytes = 0; - HIP_CHECK(rocprim::segmented_radix_sort_pairs(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - d_keys_output, - d_values_input, - d_values_output, - size, - segments_count, - d_offsets, - d_offsets + 1, - 0, - sizeof(key_type) * 8, - stream, - false)); - - HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); - HIP_CHECK(hipDeviceSynchronize()); - - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - HIP_CHECK(rocprim::segmented_radix_sort_pairs(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - d_keys_output, - d_values_input, - d_values_output, - size, - segments_count, - d_offsets, - d_offsets + 1, - 0, - sizeof(key_type) * 8, - stream, - false)); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) - { - HIP_CHECK(rocprim::segmented_radix_sort_pairs(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - d_keys_output, - d_values_input, - d_values_output, - size, - segments_count, - d_offsets, - d_offsets + 1, - 0, - sizeof(key_type) * 8, - stream, - false)); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size - * (sizeof(key_type) + sizeof(value_type))); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_temporary_storage)); - HIP_CHECK(hipFree(d_offsets)); - HIP_CHECK(hipFree(d_keys_input)); - HIP_CHECK(hipFree(d_keys_output)); - HIP_CHECK(hipFree(d_values_input)); - HIP_CHECK(hipFree(d_values_output)); -} - template -void add_sort_pairs_benchmarks(std::vector& benchmarks, - size_t max_bytes, - size_t min_size, - size_t target_size, - const managed_seed& seed, - hipStream_t stream) +void add_benchmarks(benchmark_utils::executor& executor, size_t bytes) { - // Calculate the number of elements - size_t max_size = max_bytes / sizeof(KeyT); + constexpr std::array segment_counts{10, 100, 1000, 2500, 5000, 7500, 10000, 100000}; + constexpr std::array segment_lengths{30, 256, 3000, 300000}; + + constexpr size_t min_size = 30000; + size_t max_size = bytes / sizeof(KeyT); - std::string key_name = Traits::name(); - std::string value_name = Traits::name(); for(const auto segment_count : segment_counts) { for(const auto segment_length : segment_lengths) { + // This check is also present in device_segmented_radix_sort_pairs_benchmark its run() + // We need it here to prevent Google Benchmark causing an infinite loop const auto number_of_elements = segment_count * segment_length; - if(number_of_elements > max_size || number_of_elements < min_size) + if(number_of_elements < min_size || number_of_elements > max_size) { continue; } - benchmarks.push_back(benchmark::RegisterBenchmark( - bench_naming::format_name( - "{lvl:device,algo:radix_sort_segmented,key_type:" + key_name + ",value_type:" - + value_name + ",segment_count:" + std::to_string(segment_count) - + ",segment_length:" + std::to_string(segment_length) + ",cfg:default_config}") - .c_str(), - [=](benchmark::State& state) - { - run_sort_pairs_benchmark(state, - segment_count, - segment_length, - target_size, - seed, - stream); - })); + + executor.queue_instance( + device_segmented_radix_sort_pairs_benchmark(segment_count, + segment_length)); } } } int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); -#ifdef BENCHMARK_CONFIG_TUNING - // optionally run an evenly split subset of benchmarks, when making multiple program invocations - parser.set_optional("parallel_instance", - "parallel_instance", - 0, - "parallel instance index"); - parser.set_optional("parallel_instances", - "parallel_instances", - 1, - "total parallel instances"); -#endif - - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); - - // HIP - hipStream_t stream = 0; // default + size_t bytes = 128 * benchmark_utils::MiB; - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); + benchmark_utils::executor executor(argc, argv, bytes, 10, 5); - // Add benchmarks - std::vector benchmarks; -#ifdef BENCHMARK_CONFIG_TUNING - (void)min_size; - const int parallel_instance = parser.get("parallel_instance"); - const int parallel_instances = parser.get("parallel_instances"); - config_autotune_register::register_benchmark_subset(benchmarks, - parallel_instance, - parallel_instances, - bytes, - seed, - stream); -#else +#ifndef BENCHMARK_CONFIG_TUNING using custom_float2 = common::custom_type; using custom_double2 = common::custom_type; - add_sort_pairs_benchmarks(benchmarks, bytes, min_size, bytes / 2, seed, stream); - add_sort_pairs_benchmarks(benchmarks, - bytes, - min_size, - bytes / 2, - seed, - stream); - add_sort_pairs_benchmarks(benchmarks, bytes, min_size, bytes / 2, seed, stream); - add_sort_pairs_benchmarks(benchmarks, - bytes, - min_size, - bytes / 2, - seed, - stream); - add_sort_pairs_benchmarks(benchmarks, - bytes, - min_size, - bytes / 2, - seed, - stream); - add_sort_pairs_benchmarks(benchmarks, - bytes, - min_size, - bytes / 2, - seed, - stream); - add_sort_pairs_benchmarks(benchmarks, - bytes, - min_size, - bytes / 2, - seed, - stream); - add_sort_pairs_benchmarks(benchmarks, - bytes, - min_size, - bytes / 2, - seed, - stream); - add_sort_pairs_benchmarks(benchmarks, - bytes, - min_size, - bytes / 2, - seed, - stream); -#endif - - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } + add_benchmarks(executor, bytes); + add_benchmarks(executor, bytes); + add_benchmarks(executor, bytes); + add_benchmarks(executor, bytes); + add_benchmarks(executor, bytes); + add_benchmarks(executor, bytes); + add_benchmarks(executor, bytes); + add_benchmarks(executor, bytes); + add_benchmarks(executor, bytes); +#endif - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.cpp.in b/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.cpp.in index 5618b3b26..3642eed2c 100644 --- a/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.cpp.in +++ b/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.cpp.in @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -29,10 +29,9 @@ namespace { -auto benchmarks = config_autotune_register::create_bulk( - device_segmented_radix_sort_benchmark_generator< - @LongBits@, - 8, +auto unused = benchmark_utils::executor::queue_autotune( + device_segmented_radix_sort_pairs_benchmark_generator< + @RadixBits@, @BlockSize@, @ItemsPerThread@, @WarpSmallLWS@, diff --git a/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.hpp b/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.hpp index a5a9e547a..58bbdd57b 100644 --- a/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.hpp +++ b/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.hpp @@ -26,6 +26,7 @@ #include "benchmark_utils.hpp" #include "../common/utils_data_generation.hpp" +#include "../common/utils_device_ptr.hpp" // Google Benchmark #include @@ -45,7 +46,6 @@ #include #include #include -#include #include template @@ -67,8 +67,7 @@ std::string config_name() const rocprim::detail::segmented_radix_sort_config_params config = Config(); return "{bs:" + std::to_string(config.kernel_config.block_size) + ",ipt:" + std::to_string(config.kernel_config.items_per_thread) - + ",lrb:" + std::to_string(config.long_radix_bits) - + ",srb:" + std::to_string(config.short_radix_bits) + + ",rb:" + std::to_string(config.radix_bits) + ",eupws:" + std::to_string(config.enable_unpartitioned_warp_sort) + ",wsc:" + warp_sort_config_name(config.warp_sort_config) + "}"; } @@ -79,29 +78,49 @@ inline std::string config_name() return "default_config"; } -template -struct device_segmented_radix_sort_benchmark : public config_autotune_interface +template +struct device_segmented_radix_sort_pairs_benchmark : public benchmark_utils::autotune_interface { +private: + std::vector segment_counts; + std::vector segment_lengths; + size_t total_size; + +public: + device_segmented_radix_sort_pairs_benchmark(size_t segment_count, size_t segment_length) + { + segment_counts.push_back(segment_count); + segment_lengths.push_back(segment_length); + } + + device_segmented_radix_sort_pairs_benchmark(const std::vector& segment_counts, + const std::vector& segment_lengths) + { + this->segment_counts = segment_counts; + this->segment_lengths = segment_lengths; + } + std::string name() const override { using namespace std::string_literals; - const rocprim::detail::segmented_radix_sort_config_params config = Config(); - return bench_naming::format_name("{lvl:device,algo:segmented_radix_sort,key_type:" - + std::string(Traits::name()) - + ",value_type:" + std::string(Traits::name()) - + ",cfg:" + config_name() + "}"); + return bench_naming::format_name( + "{lvl:device,algo:segmented_radix_sort,key_type:" + std::string(Traits::name()) + + ",value_type:" + std::string(Traits::name()) + + (segment_counts.size() == 1 ? ",segment_count:"s + std::to_string(segment_counts[0]) + : ""s) + + (segment_lengths.size() == 1 + ? ",segment_length:"s + std::to_string(segment_lengths[0]) + : ""s) + + ",cfg:" + config_name() + "}"); } - static constexpr unsigned int batch_size = 10; - static constexpr unsigned int warmup_size = 5; - - void run_benchmark(benchmark::State& state, - size_t num_segments, - size_t mean_segment_length, - size_t target_size, - const managed_seed& seed, - hipStream_t stream) const + void run_benchmark(benchmark_utils::state&& state, + size_t num_segments, + size_t mean_segment_length) { + const auto& stream = state.stream; + const auto& seed = state.seed; + using offset_type = int; using key_type = Key; using value_type = Value; @@ -144,162 +163,94 @@ struct device_segmented_radix_sort_benchmark : public config_autotune_interface common::generate_limits::max(), seed.get_0()); - size_t batch_size = 1; - if(size < target_size) - { - batch_size = (target_size + size - 1) / size; - } + common::device_ptr d_offsets(offsets); + + common::device_ptr d_keys_input(keys_input); + common::device_ptr d_keys_output(size); + + common::device_ptr d_values_input(values_input); + common::device_ptr d_values_output(size); - offset_type* d_offsets; - HIP_CHECK(hipMalloc(&d_offsets, offsets.size() * sizeof(offset_type))); - HIP_CHECK(hipMemcpy(d_offsets, - offsets.data(), - offsets.size() * sizeof(offset_type), - hipMemcpyHostToDevice)); - - key_type* d_keys_input; - key_type* d_keys_output; - HIP_CHECK(hipMalloc(&d_keys_input, size * sizeof(key_type))); - HIP_CHECK(hipMalloc(&d_keys_output, size * sizeof(key_type))); - HIP_CHECK(hipMemcpy(d_keys_input, - keys_input.data(), - size * sizeof(key_type), - hipMemcpyHostToDevice)); - - value_type* d_values_input; - value_type* d_values_output; - HIP_CHECK(hipMalloc(&d_values_input, size * sizeof(value_type))); - HIP_CHECK(hipMalloc(&d_values_output, size * sizeof(value_type))); - HIP_CHECK(hipMemcpy(d_values_input, - values_input.data(), - size * sizeof(value_type), - hipMemcpyHostToDevice)); - - void* d_temporary_storage = nullptr; size_t temporary_storage_bytes = 0; - HIP_CHECK(rocprim::segmented_radix_sort_pairs(d_temporary_storage, + HIP_CHECK(rocprim::segmented_radix_sort_pairs(nullptr, temporary_storage_bytes, - d_keys_input, - d_keys_output, - d_values_input, - d_values_output, + d_keys_input.get(), + d_keys_output.get(), + d_values_input.get(), + d_values_output.get(), size, segments_count, - d_offsets, - d_offsets + 1, + d_offsets.get(), + d_offsets.get() + 1, 0, sizeof(key_type) * 8, stream, false)); - HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); - HIP_CHECK(hipDeviceSynchronize()); - - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - HIP_CHECK(rocprim::segmented_radix_sort_pairs(d_temporary_storage, - temporary_storage_bytes, - d_keys_input, - d_keys_output, - d_values_input, - d_values_output, - size, - segments_count, - d_offsets, - d_offsets + 1, - 0, - sizeof(key_type) * 8, - stream, - false)); - } + common::device_ptr d_temporary_storage(temporary_storage_bytes); HIP_CHECK(hipDeviceSynchronize()); - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) + state.run( + [&] { - HIP_CHECK(rocprim::segmented_radix_sort_pairs(d_temporary_storage, + HIP_CHECK(rocprim::segmented_radix_sort_pairs(d_temporary_storage.get(), temporary_storage_bytes, - d_keys_input, - d_keys_output, - d_values_input, - d_values_output, + d_keys_input.get(), + d_keys_output.get(), + d_values_input.get(), + d_values_output.get(), size, segments_count, - d_offsets, - d_offsets + 1, + d_offsets.get(), + d_offsets.get() + 1, 0, sizeof(key_type) * 8, stream, false)); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); + }); - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(key_type)); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_temporary_storage)); - HIP_CHECK(hipFree(d_offsets)); - HIP_CHECK(hipFree(d_keys_input)); - HIP_CHECK(hipFree(d_keys_output)); - HIP_CHECK(hipFree(d_values_input)); - HIP_CHECK(hipFree(d_values_output)); + total_size += size; } - void run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { - // Calculate the number of elements - size_t size = bytes / sizeof(Key); - - constexpr std::array - segment_counts{10, 100, 1000, 2500, 5000, 7500, 10000, 100000}; - constexpr std::array segment_lengths{30, 256, 3000, 300000}; + total_size = 0; - for(const auto segment_count : segment_counts) + if(segment_counts.size() == 1) { - for(const auto segment_length : segment_lengths) + run_benchmark(std::forward(state), + segment_counts[0], + segment_lengths[0]); + } + else + { + state.accumulate_total_gbench_iterations_every_run(); + + constexpr size_t min_size = 300000; + constexpr size_t max_size = 33554432; + + for(const auto segment_count : segment_counts) { - const auto number_of_elements = segment_count * segment_length; - if(number_of_elements > 33554432 || number_of_elements < 300000) + for(const auto segment_length : segment_lengths) { - continue; + const auto number_of_elements = segment_count * segment_length; + if(number_of_elements < min_size || number_of_elements > max_size) + { + continue; + } + + run_benchmark(std::forward(state), + segment_count, + segment_length); } - - run_benchmark(state, segment_count, segment_length, size, seed, stream); } } + + state.set_throughput(total_size, sizeof(Key) + sizeof(Value)); } }; -template class T, bool enable, Tp... Idx> -struct decider; - -template -struct device_segmented_radix_sort_benchmark_generator +struct device_segmented_radix_sort_pairs_benchmark_generator { template - static auto __create(std::vector>& storage) -> - typename std::enable_if<((key_size + value_type) * BlockSize * ItemsPerThread - <= TUNING_SHARED_MEMORY_MAX), - void>::type + static auto _create(std::vector>& storage) + -> std::enable_if_t<((key_size + value_type) * BlockSize * ItemsPerThread + <= TUNING_SHARED_MEMORY_MAX)> { - storage.emplace_back(std::make_unique segment_counts{10, 100, 1000, 2500, 5000, 7500, 10000, 100000}; + const std::vector segment_lengths{30, 256, 3000, 300000}; + + storage.emplace_back(std::make_unique, rocprim::WarpSortConfig, - UnpartitionWarpAllowed>>>()); + UnpartitionWarpAllowed>>>(segment_counts, segment_lengths)); } + template - static auto __create(std::vector>&) -> - typename std::enable_if::type + static auto _create(std::vector>&) + -> std::enable_if_t {} - static void create(std::vector>& storage) - { - __create(storage); - } -}; -template class T, Tp... Idx> -struct decider -{ - inline static void - do_the_thing(std::vector>& storage) + static void create(std::vector>& storage) { - static_for_each, T>(storage); + _create(storage); } }; -template class T, Tp... Idx> -struct decider -{ - inline static void - do_the_thing(std::vector>& /*storage*/) - {} -}; - #endif // ROCPRIM_BENCHMARK_DEVICE_SEGMENTED_RADIX_SORT_PAIRS_PARALLEL_HPP_ diff --git a/benchmark/benchmark_device_segmented_reduce.cpp b/benchmark/benchmark_device_segmented_reduce.cpp index 2dc834675..abca2a79c 100644 --- a/benchmark/benchmark_device_segmented_reduce.cpp +++ b/benchmark/benchmark_device_segmented_reduce.cpp @@ -22,14 +22,9 @@ #include "benchmark_device_segmented_reduce.parallel.hpp" #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" #include "../common/utils_custom_type.hpp" -// Google Benchmark -#include - // HIP API #include @@ -46,33 +41,8 @@ #include #include -#ifndef DEFAULT_BYTES -const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; -#endif - -#define ADD_BENCHMARK(T, SEGMENTS, INSTANCE) \ - benchmark::internal::Benchmark* benchmark = benchmark::RegisterBenchmark( \ - bench_naming::format_name("{lvl:device,algo:reduce_segmented,key_type:" #T \ - ",segment_count:" \ - + std::to_string(SEGMENTS) + ",cfg:default_config}") \ - .c_str(), \ - [INSTANCE](benchmark::State& state, \ - size_t _desired_segments, \ - size_t _size, \ - const managed_seed& _seed, \ - hipStream_t _stream) \ - { INSTANCE.run_benchmark(state, _desired_segments, _size, _seed, _stream); }, \ - SEGMENTS, \ - bytes, \ - seed, \ - stream); - -#define CREATE_BENCHMARK(T, SEGMENTS) \ - { \ - const device_segmented_reduce_benchmark instance; \ - ADD_BENCHMARK(T, SEGMENTS, instance) \ - benchmarks.emplace_back(benchmark); \ - } +#define CREATE_BENCHMARK(T, SEGMENTS) \ + executor.queue_instance(device_segmented_reduce_benchmark(SEGMENTS)); #define BENCHMARK_TYPE(type) \ CREATE_BENCHMARK(type, 1) \ @@ -81,11 +51,11 @@ const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; CREATE_BENCHMARK(type, 1000) \ CREATE_BENCHMARK(type, 10000) -void add_benchmarks(std::vector& benchmarks, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) +int main(int argc, char* argv[]) { + benchmark_utils::executor executor(argc, argv, 128 * benchmark_utils::MiB, 10, 5); + +#ifndef BENCHMARK_CONFIG_TUNING using custom_float2 = common::custom_type; using custom_double2 = common::custom_type; @@ -99,80 +69,7 @@ void add_benchmarks(std::vector& benchmarks, BENCHMARK_TYPE(custom_double2) BENCHMARK_TYPE(rocprim::int128_t) BENCHMARK_TYPE(rocprim::uint128_t) -} - -int main(int argc, char* argv[]) -{ - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - // fixed seed as a random seed adds a lot of variance - parser.set_optional("seed", "seed", "321", get_seed_message()); -#ifdef BENCHMARK_CONFIG_TUNING - // optionally run an evenly split subset of benchmarks, when making multiple program invocations - parser.set_optional("parallel_instance", - "parallel_instance", - 0, - "parallel instance index"); - parser.set_optional("parallel_instances", - "parallel_instances", - 1, - "total parallel instances"); -#endif - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); - - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks = {}; -#ifdef BENCHMARK_CONFIG_TUNING - const int parallel_instance = parser.get("parallel_instance"); - const int parallel_instances = parser.get("parallel_instances"); - config_autotune_register::register_benchmark_subset(benchmarks, - parallel_instance, - parallel_instances, - bytes, - seed, - stream); -#else - add_benchmarks(benchmarks, bytes, seed, stream); #endif - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } - - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_segmented_reduce.parallel.cpp.in b/benchmark/benchmark_device_segmented_reduce.parallel.cpp.in index ebb05ca7b..f3ac41745 100644 --- a/benchmark/benchmark_device_segmented_reduce.parallel.cpp.in +++ b/benchmark/benchmark_device_segmented_reduce.parallel.cpp.in @@ -26,7 +26,7 @@ #include "benchmark_device_segmented_reduce.parallel.hpp" namespace { - auto benchmark = config_autotune_register::create, rocprim::reduce_config<@BlockSize@u, @ItemsPerThread@u, rocprim::block_reduce_algorithm::using_warp_reduce>>>(); } diff --git a/benchmark/benchmark_device_segmented_reduce.parallel.hpp b/benchmark/benchmark_device_segmented_reduce.parallel.hpp index c42b0347e..074abcc35 100644 --- a/benchmark/benchmark_device_segmented_reduce.parallel.hpp +++ b/benchmark/benchmark_device_segmented_reduce.parallel.hpp @@ -25,6 +25,8 @@ #include "benchmark_utils.hpp" +#include "../common/utils_device_ptr.hpp" + // Google Benchmark #include @@ -37,6 +39,7 @@ #include #include #include +#include #include #include @@ -71,28 +74,42 @@ inline std::string config_name() return "default_config"; } -template, typename Config = rocprim::default_config> -struct device_segmented_reduce_benchmark : public config_autotune_interface +struct device_segmented_reduce_benchmark : public benchmark_utils::autotune_interface { +private: + std::vector desired_segments; + size_t total_size; - std::string name() const override +public: + device_segmented_reduce_benchmark() + { + this->desired_segments = std::vector{1, 10, 100, 1000, 10000}; + } + + device_segmented_reduce_benchmark(size_t desired_segment) { - return bench_naming::format_name("{lvl:device,algo:segmented_reduce,key_type:" - + std::string(Traits::name()) - + ",cfg:" + config_name() + "}"); + desired_segments.push_back(desired_segment); } - static constexpr unsigned int batch_size = 10; - static constexpr unsigned int warmup_size = 5; + std::string name() const override + { + return bench_naming::format_name( + "{lvl:device,algo:segmented_reduce,key_type:" + std::string(Traits::name()) + + (desired_segments.size() == 1 + ? ",segment_count:" + std::to_string(desired_segments[0]) + : "") + + ",cfg:" + config_name() + "}"); + } - void run_benchmark(benchmark::State& state, - size_t desired_segment, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const + void run_benchmark(benchmark_utils::state&& state, size_t desired_segment) { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + using offset_type = int; using value_type = T; @@ -120,118 +137,59 @@ struct device_segmented_reduce_benchmark : public config_autotune_interface std::vector values_input(size); std::iota(values_input.begin(), values_input.end(), 0); - offset_type* d_offsets; - HIP_CHECK(hipMalloc(reinterpret_cast(&d_offsets), - (segments_count + 1) * sizeof(offset_type))); - HIP_CHECK(hipMemcpy(d_offsets, - offsets.data(), - (segments_count + 1) * sizeof(offset_type), - hipMemcpyHostToDevice)); - - value_type* d_values_input; - HIP_CHECK(hipMalloc(reinterpret_cast(&d_values_input), size * sizeof(value_type))); - HIP_CHECK(hipMemcpy(d_values_input, - values_input.data(), - size * sizeof(value_type), - hipMemcpyHostToDevice)); - - value_type* d_aggregates_output; - HIP_CHECK(hipMalloc(reinterpret_cast(&d_aggregates_output), - segments_count * sizeof(value_type))); + common::device_ptr d_offsets(offsets); + + common::device_ptr d_values_input(values_input); + + common::device_ptr d_aggregates_output(segments_count); rocprim::plus reduce_op; value_type init(0); - void* d_temporary_storage = nullptr; size_t temporary_storage_bytes = 0; - HIP_CHECK(rp::segmented_reduce(d_temporary_storage, + HIP_CHECK(rp::segmented_reduce(nullptr, temporary_storage_bytes, - d_values_input, - d_aggregates_output, + d_values_input.get(), + d_aggregates_output.get(), segments_count, - d_offsets, - d_offsets + 1, + d_offsets.get(), + d_offsets.get() + 1, reduce_op, init, stream)); - HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); - HIP_CHECK(hipDeviceSynchronize()); - - // Warm-up - for(size_t i = 0; i < warmup_size; i++) - { - HIP_CHECK(rp::segmented_reduce(d_temporary_storage, - temporary_storage_bytes, - d_values_input, - d_aggregates_output, - segments_count, - d_offsets, - d_offsets + 1, - reduce_op, - init, - stream)); - } + common::device_ptr d_temporary_storage(temporary_storage_bytes); HIP_CHECK(hipDeviceSynchronize()); - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; i++) + state.run( + [&] { - HIP_CHECK(rp::segmented_reduce(d_temporary_storage, + HIP_CHECK(rp::segmented_reduce(d_temporary_storage.get(), temporary_storage_bytes, - d_values_input, - d_aggregates_output, + d_values_input.get(), + d_aggregates_output.get(), segments_count, - d_offsets, - d_offsets + 1, + d_offsets.get(), + d_offsets.get() + 1, reduce_op, init, stream)); - } + }); - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(value_type)); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_temporary_storage)); - HIP_CHECK(hipFree(d_offsets)); - HIP_CHECK(hipFree(d_values_input)); - HIP_CHECK(hipFree(d_aggregates_output)); + total_size += size; } - void run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { - constexpr std::array desired_segments{1, 10, 100, 1000, 10000}; + total_size = 0; for(const auto desired_segment : desired_segments) { - run_benchmark(state, desired_segment, bytes, seed, stream); + run_benchmark(std::forward(state), desired_segment); } + + state.set_throughput(total_size, sizeof(T)); } }; diff --git a/benchmark/benchmark_device_select.cpp b/benchmark/benchmark_device_select.cpp index 9cfd910e8..145518d50 100644 --- a/benchmark/benchmark_device_select.cpp +++ b/benchmark/benchmark_device_select.cpp @@ -22,225 +22,118 @@ #include "benchmark_device_select.parallel.hpp" #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" -#ifndef BENCHMARK_CONFIG_TUNING - #include "../common/utils_custom_type.hpp" -#endif +#define CREATE_SELECT_PREDICATED_FLAG_BENCHMARK(T, F, p) \ + executor.queue_instance( \ + device_select_predicated_flag_benchmark()); -// Google Benchmark -#include "benchmark/benchmark.h" +#define CREATE_SELECT_FLAG_BENCHMARK(T, F, p) \ + executor.queue_instance(device_select_flag_benchmark()); -// HIP API -#include +#define CREATE_SELECT_PREDICATE_BENCHMARK(T, p) \ + executor.queue_instance(device_select_predicate_benchmark()); -// rocPRIM -#ifndef BENCHMARK_CONFIG_TUNING - #include - #include -#endif +#define CREATE_UNIQUE_BENCHMARK(T, p) \ + executor.queue_instance(device_select_unique_benchmark()); -#include -#include -#include -#ifndef BENCHMARK_CONFIG_TUNING - #include -#endif +#define CREATE_UNIQUE_BY_KEY_BENCHMARK(K, V, p) \ + executor.queue_instance( \ + device_select_unique_by_key_benchmark()); -#ifndef DEFAULT_BYTES -const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; -#endif - -#define CREATE_SELECT_PREDICATED_FLAG_BENCHMARK(T, F, p) \ - { \ - const device_select_predicated_flag_benchmark instance; \ - REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ - } - -#define CREATE_SELECT_FLAG_BENCHMARK(T, F, p) \ - { \ - const device_select_flag_benchmark instance; \ - REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ - } - -#define CREATE_SELECT_PREDICATE_BENCHMARK(T, p) \ - { \ - const device_select_predicate_benchmark instance; \ - REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ - } - -#define CREATE_UNIQUE_BENCHMARK(T, p) \ - { \ - const device_select_unique_benchmark instance; \ - REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ - } - -#define CREATE_UNIQUE_BY_KEY_BENCHMARK(K, V, p) \ - { \ - const device_select_unique_by_key_benchmark instance; \ - REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ - } - -#define BENCHMARK_SELECT_PREDICATED_FLAG_TYPE(type, value) \ - CREATE_SELECT_PREDICATED_FLAG_BENCHMARK(type, value, select_probability::p005); \ - CREATE_SELECT_PREDICATED_FLAG_BENCHMARK(type, value, select_probability::p025); \ - CREATE_SELECT_PREDICATED_FLAG_BENCHMARK(type, value, select_probability::p050); \ +#define BENCHMARK_SELECT_PREDICATED_FLAG_TYPE(type, value) \ + CREATE_SELECT_PREDICATED_FLAG_BENCHMARK(type, value, select_probability::p005) \ + CREATE_SELECT_PREDICATED_FLAG_BENCHMARK(type, value, select_probability::p025) \ + CREATE_SELECT_PREDICATED_FLAG_BENCHMARK(type, value, select_probability::p050) \ CREATE_SELECT_PREDICATED_FLAG_BENCHMARK(type, value, select_probability::p075) -#define BENCHMARK_SELECT_FLAG_TYPE(type, value) \ - CREATE_SELECT_FLAG_BENCHMARK(type, value, select_probability::p005); \ - CREATE_SELECT_FLAG_BENCHMARK(type, value, select_probability::p025); \ - CREATE_SELECT_FLAG_BENCHMARK(type, value, select_probability::p050); \ +#define BENCHMARK_SELECT_FLAG_TYPE(type, value) \ + CREATE_SELECT_FLAG_BENCHMARK(type, value, select_probability::p005) \ + CREATE_SELECT_FLAG_BENCHMARK(type, value, select_probability::p025) \ + CREATE_SELECT_FLAG_BENCHMARK(type, value, select_probability::p050) \ CREATE_SELECT_FLAG_BENCHMARK(type, value, select_probability::p075) -#define BENCHMARK_SELECT_PREDICATE_TYPE(type) \ - CREATE_SELECT_PREDICATE_BENCHMARK(type, select_probability::p005); \ - CREATE_SELECT_PREDICATE_BENCHMARK(type, select_probability::p025); \ - CREATE_SELECT_PREDICATE_BENCHMARK(type, select_probability::p050); \ +#define BENCHMARK_SELECT_PREDICATE_TYPE(type) \ + CREATE_SELECT_PREDICATE_BENCHMARK(type, select_probability::p005) \ + CREATE_SELECT_PREDICATE_BENCHMARK(type, select_probability::p025) \ + CREATE_SELECT_PREDICATE_BENCHMARK(type, select_probability::p050) \ CREATE_SELECT_PREDICATE_BENCHMARK(type, select_probability::p075) -#define BENCHMARK_UNIQUE_TYPE(type) \ - CREATE_UNIQUE_BENCHMARK(type, select_probability::p005); \ - CREATE_UNIQUE_BENCHMARK(type, select_probability::p025); \ - CREATE_UNIQUE_BENCHMARK(type, select_probability::p050); \ +#define BENCHMARK_UNIQUE_TYPE(type) \ + CREATE_UNIQUE_BENCHMARK(type, select_probability::p005) \ + CREATE_UNIQUE_BENCHMARK(type, select_probability::p025) \ + CREATE_UNIQUE_BENCHMARK(type, select_probability::p050) \ CREATE_UNIQUE_BENCHMARK(type, select_probability::p075) -#define BENCHMARK_UNIQUE_BY_KEY_TYPE(K, V) \ - CREATE_UNIQUE_BY_KEY_BENCHMARK(K, V, select_probability::p005); \ - CREATE_UNIQUE_BY_KEY_BENCHMARK(K, V, select_probability::p025); \ - CREATE_UNIQUE_BY_KEY_BENCHMARK(K, V, select_probability::p050); \ +#define BENCHMARK_UNIQUE_BY_KEY_TYPE(K, V) \ + CREATE_UNIQUE_BY_KEY_BENCHMARK(K, V, select_probability::p005) \ + CREATE_UNIQUE_BY_KEY_BENCHMARK(K, V, select_probability::p025) \ + CREATE_UNIQUE_BY_KEY_BENCHMARK(K, V, select_probability::p050) \ CREATE_UNIQUE_BY_KEY_BENCHMARK(K, V, select_probability::p075) int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); -#ifdef BENCHMARK_CONFIG_TUNING - // optionally run an evenly split subset of benchmarks, when making multiple program invocations - parser.set_optional("parallel_instance", - "parallel_instance", - 0, - "parallel instance index"); - parser.set_optional("parallel_instances", - "parallel_instances", - 1, - "total parallel instances"); -#endif - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); - - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks = {}; -#ifdef BENCHMARK_CONFIG_TUNING - const int parallel_instance = parser.get("parallel_instance"); - const int parallel_instances = parser.get("parallel_instances"); - config_autotune_register::register_benchmark_subset(benchmarks, - parallel_instance, - parallel_instances, - bytes, - seed, - stream); -#else + benchmark_utils::executor executor(argc, argv, 128 * benchmark_utils::MiB, 10, 5); + +#ifndef BENCHMARK_CONFIG_TUNING using custom_double2 = common::custom_type; using custom_int_double = common::custom_type; using huge_float2 = common::custom_huge_type<1024, float, float>; - BENCHMARK_SELECT_FLAG_TYPE(int, unsigned char); - BENCHMARK_SELECT_FLAG_TYPE(float, unsigned char); - BENCHMARK_SELECT_FLAG_TYPE(double, unsigned char); - BENCHMARK_SELECT_FLAG_TYPE(uint8_t, uint8_t); - BENCHMARK_SELECT_FLAG_TYPE(int8_t, int8_t); - BENCHMARK_SELECT_FLAG_TYPE(rocprim::half, int8_t); - BENCHMARK_SELECT_FLAG_TYPE(custom_double2, unsigned char); - BENCHMARK_SELECT_FLAG_TYPE(rocprim::int128_t, unsigned char); - BENCHMARK_SELECT_FLAG_TYPE(rocprim::uint128_t, unsigned char); - BENCHMARK_SELECT_FLAG_TYPE(huge_float2, unsigned char); - - BENCHMARK_SELECT_PREDICATE_TYPE(int); - BENCHMARK_SELECT_PREDICATE_TYPE(float); - BENCHMARK_SELECT_PREDICATE_TYPE(double); - BENCHMARK_SELECT_PREDICATE_TYPE(uint8_t); - BENCHMARK_SELECT_PREDICATE_TYPE(int8_t); - BENCHMARK_SELECT_PREDICATE_TYPE(rocprim::half); - BENCHMARK_SELECT_PREDICATE_TYPE(custom_int_double); - BENCHMARK_SELECT_PREDICATE_TYPE(rocprim::int128_t); - BENCHMARK_SELECT_PREDICATE_TYPE(rocprim::uint128_t); - BENCHMARK_SELECT_PREDICATE_TYPE(huge_float2); - - BENCHMARK_SELECT_PREDICATED_FLAG_TYPE(int, unsigned char); - BENCHMARK_SELECT_PREDICATED_FLAG_TYPE(float, unsigned char); - BENCHMARK_SELECT_PREDICATED_FLAG_TYPE(double, unsigned char); - BENCHMARK_SELECT_PREDICATED_FLAG_TYPE(uint8_t, uint8_t); - BENCHMARK_SELECT_PREDICATED_FLAG_TYPE(int8_t, int8_t); - BENCHMARK_SELECT_PREDICATED_FLAG_TYPE(rocprim::half, int8_t); - BENCHMARK_SELECT_PREDICATED_FLAG_TYPE(custom_double2, unsigned char); - BENCHMARK_SELECT_PREDICATED_FLAG_TYPE(rocprim::int128_t, unsigned char); - BENCHMARK_SELECT_PREDICATED_FLAG_TYPE(rocprim::uint128_t, unsigned char); - BENCHMARK_SELECT_PREDICATED_FLAG_TYPE(huge_float2, unsigned char); - - BENCHMARK_UNIQUE_TYPE(int); - BENCHMARK_UNIQUE_TYPE(float); - BENCHMARK_UNIQUE_TYPE(double); - BENCHMARK_UNIQUE_TYPE(uint8_t); - BENCHMARK_UNIQUE_TYPE(int8_t); - BENCHMARK_UNIQUE_TYPE(rocprim::half); - BENCHMARK_UNIQUE_TYPE(custom_int_double); - BENCHMARK_UNIQUE_TYPE(rocprim::int128_t); - BENCHMARK_UNIQUE_TYPE(rocprim::uint128_t); - BENCHMARK_UNIQUE_TYPE(huge_float2); - - BENCHMARK_UNIQUE_BY_KEY_TYPE(int, int); - BENCHMARK_UNIQUE_BY_KEY_TYPE(float, double); - BENCHMARK_UNIQUE_BY_KEY_TYPE(double, custom_double2); - BENCHMARK_UNIQUE_BY_KEY_TYPE(uint8_t, uint8_t); - BENCHMARK_UNIQUE_BY_KEY_TYPE(int8_t, double); - BENCHMARK_UNIQUE_BY_KEY_TYPE(rocprim::half, rocprim::half); - BENCHMARK_UNIQUE_BY_KEY_TYPE(custom_int_double, custom_int_double); - BENCHMARK_UNIQUE_BY_KEY_TYPE(rocprim::int128_t, rocprim::int128_t); - BENCHMARK_UNIQUE_BY_KEY_TYPE(rocprim::uint128_t, rocprim::int128_t); - BENCHMARK_UNIQUE_BY_KEY_TYPE(huge_float2, huge_float2); + BENCHMARK_SELECT_FLAG_TYPE(int, unsigned char) + BENCHMARK_SELECT_FLAG_TYPE(float, unsigned char) + BENCHMARK_SELECT_FLAG_TYPE(double, unsigned char) + BENCHMARK_SELECT_FLAG_TYPE(uint8_t, uint8_t) + BENCHMARK_SELECT_FLAG_TYPE(int8_t, int8_t) + BENCHMARK_SELECT_FLAG_TYPE(rocprim::half, int8_t) + BENCHMARK_SELECT_FLAG_TYPE(custom_double2, unsigned char) + BENCHMARK_SELECT_FLAG_TYPE(rocprim::int128_t, unsigned char) + BENCHMARK_SELECT_FLAG_TYPE(rocprim::uint128_t, unsigned char) + BENCHMARK_SELECT_FLAG_TYPE(huge_float2, unsigned char) + + BENCHMARK_SELECT_PREDICATE_TYPE(int) + BENCHMARK_SELECT_PREDICATE_TYPE(float) + BENCHMARK_SELECT_PREDICATE_TYPE(double) + BENCHMARK_SELECT_PREDICATE_TYPE(uint8_t) + BENCHMARK_SELECT_PREDICATE_TYPE(int8_t) + BENCHMARK_SELECT_PREDICATE_TYPE(rocprim::half) + BENCHMARK_SELECT_PREDICATE_TYPE(custom_int_double) + BENCHMARK_SELECT_PREDICATE_TYPE(rocprim::int128_t) + BENCHMARK_SELECT_PREDICATE_TYPE(rocprim::uint128_t) + BENCHMARK_SELECT_PREDICATE_TYPE(huge_float2) + + BENCHMARK_SELECT_PREDICATED_FLAG_TYPE(int, unsigned char) + BENCHMARK_SELECT_PREDICATED_FLAG_TYPE(float, unsigned char) + BENCHMARK_SELECT_PREDICATED_FLAG_TYPE(double, unsigned char) + BENCHMARK_SELECT_PREDICATED_FLAG_TYPE(uint8_t, uint8_t) + BENCHMARK_SELECT_PREDICATED_FLAG_TYPE(int8_t, int8_t) + BENCHMARK_SELECT_PREDICATED_FLAG_TYPE(rocprim::half, int8_t) + BENCHMARK_SELECT_PREDICATED_FLAG_TYPE(custom_double2, unsigned char) + BENCHMARK_SELECT_PREDICATED_FLAG_TYPE(rocprim::int128_t, unsigned char) + BENCHMARK_SELECT_PREDICATED_FLAG_TYPE(rocprim::uint128_t, unsigned char) + BENCHMARK_SELECT_PREDICATED_FLAG_TYPE(huge_float2, unsigned char) + + BENCHMARK_UNIQUE_TYPE(int) + BENCHMARK_UNIQUE_TYPE(float) + BENCHMARK_UNIQUE_TYPE(double) + BENCHMARK_UNIQUE_TYPE(uint8_t) + BENCHMARK_UNIQUE_TYPE(int8_t) + BENCHMARK_UNIQUE_TYPE(rocprim::half) + BENCHMARK_UNIQUE_TYPE(custom_int_double) + BENCHMARK_UNIQUE_TYPE(rocprim::int128_t) + BENCHMARK_UNIQUE_TYPE(rocprim::uint128_t) + BENCHMARK_UNIQUE_TYPE(huge_float2) + + BENCHMARK_UNIQUE_BY_KEY_TYPE(int, int) + BENCHMARK_UNIQUE_BY_KEY_TYPE(float, double) + BENCHMARK_UNIQUE_BY_KEY_TYPE(double, custom_double2) + BENCHMARK_UNIQUE_BY_KEY_TYPE(uint8_t, uint8_t) + BENCHMARK_UNIQUE_BY_KEY_TYPE(int8_t, double) + BENCHMARK_UNIQUE_BY_KEY_TYPE(rocprim::half, rocprim::half) + BENCHMARK_UNIQUE_BY_KEY_TYPE(custom_int_double, custom_int_double) + BENCHMARK_UNIQUE_BY_KEY_TYPE(rocprim::int128_t, rocprim::int128_t) + BENCHMARK_UNIQUE_BY_KEY_TYPE(rocprim::uint128_t, rocprim::int128_t) + BENCHMARK_UNIQUE_BY_KEY_TYPE(huge_float2, huge_float2) #endif - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } - - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_select.parallel.cpp.in b/benchmark/benchmark_device_select.parallel.cpp.in index c9c505209..adefe7bd7 100644 --- a/benchmark/benchmark_device_select.parallel.cpp.in +++ b/benchmark/benchmark_device_select.parallel.cpp.in @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -29,6 +29,6 @@ namespace { -auto benchmarks = config_autotune_register::create_bulk( +auto unused = benchmark_utils::executor::queue_autotune( device_select_benchmark_generator<@KeyType@, @ValueType@, @BlockSize@>::create); } // namespace diff --git a/benchmark/benchmark_device_select.parallel.hpp b/benchmark/benchmark_device_select.parallel.hpp index e6dfd953c..5c6af0c6d 100644 --- a/benchmark/benchmark_device_select.parallel.hpp +++ b/benchmark/benchmark_device_select.parallel.hpp @@ -24,9 +24,11 @@ #define ROCPRIM_BENCHMARK_DEVICE_SELECT_PARALLEL_HPP_ #include "benchmark_utils.hpp" -#include "cmdparser.hpp" #include "../common/utils_data_generation.hpp" +#include "../common/utils_device_ptr.hpp" + +#include "cmdparser.hpp" #include @@ -84,14 +86,11 @@ inline const char* get_probability_name(select_probability probability) return "invalid"; } -constexpr int warmup_iter = 5; -constexpr int batch_size = 10; - template -struct device_select_flag_benchmark : public config_autotune_interface +struct device_select_flag_benchmark : public benchmark_utils::autotune_interface { std::string name() const override { @@ -103,11 +102,12 @@ struct device_select_flag_benchmark : public config_autotune_interface + ",cfg:" + partition_config_name() + "}"); } - void run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + // Calculate the number of elements size_t size = bytes / sizeof(DataType); @@ -132,35 +132,20 @@ struct device_select_flag_benchmark : public config_autotune_interface flags_0 = get_random_data01(size, get_probability(Probability), seed.get_1()); } - DataType* d_input{}; - HIP_CHECK(hipMalloc(&d_input, size * sizeof(*d_input))); - HIP_CHECK(hipMemcpy(d_input, input.data(), size * sizeof(*d_input), hipMemcpyHostToDevice)); + common::device_ptr d_input(input); - FlagType* d_flags_0{}; - FlagType* d_flags_1{}; - FlagType* d_flags_2{}; - HIP_CHECK(hipMalloc(&d_flags_0, size * sizeof(*d_flags_0))); - HIP_CHECK( - hipMemcpy(d_flags_0, flags_0.data(), size * sizeof(*d_flags_0), hipMemcpyHostToDevice)); + common::device_ptr d_flags_0(flags_0); + common::device_ptr d_flags_1; + common::device_ptr d_flags_2; if(is_tuning) { - HIP_CHECK(hipMalloc(&d_flags_1, size * sizeof(*d_flags_1))); - HIP_CHECK(hipMemcpy(d_flags_1, - flags_1.data(), - size * sizeof(*d_flags_1), - hipMemcpyHostToDevice)); - HIP_CHECK(hipMalloc(&d_flags_2, size * sizeof(*d_flags_2))); - HIP_CHECK(hipMemcpy(d_flags_2, - flags_2.data(), - size * sizeof(*d_flags_2), - hipMemcpyHostToDevice)); + d_flags_1.store(flags_1); + d_flags_2.store(flags_2); } - DataType* d_output{}; - HIP_CHECK(hipMalloc(&d_output, size * sizeof(*d_output))); + common::device_ptr d_output(size); - unsigned int* d_selected_count_output{}; - HIP_CHECK(hipMalloc(&d_selected_count_output, sizeof(*d_selected_count_output))); + common::device_ptr d_selected_count_output(1); const auto dispatch = [&](void* d_temp_storage, size_t& temp_storage_size_bytes) { @@ -168,69 +153,30 @@ struct device_select_flag_benchmark : public config_autotune_interface { HIP_CHECK(rocprim::select(d_temp_storage, temp_storage_size_bytes, - d_input, + d_input.get(), d_flags, - d_output, - d_selected_count_output, + d_output.get(), + d_selected_count_output.get(), size, stream)); }; - dispatch_flags(d_flags_0); + dispatch_flags(d_flags_0.get()); if(is_tuning) { - dispatch_flags(d_flags_1); - dispatch_flags(d_flags_2); + dispatch_flags(d_flags_1.get()); + dispatch_flags(d_flags_2.get()); } }; // Allocate temporary storage memory size_t temp_storage_size_bytes{}; dispatch(nullptr, temp_storage_size_bytes); - void* d_temp_storage{}; - HIP_CHECK(hipMalloc(&d_temp_storage, temp_storage_size_bytes)); - - for(int i = 0; i < warmup_iter; ++i) - { - dispatch(d_temp_storage, temp_storage_size_bytes); - } - HIP_CHECK(hipDeviceSynchronize()); - - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); + common::device_ptr d_temp_storage(temp_storage_size_bytes); - for(auto _ : state) - { - HIP_CHECK(hipEventRecord(start, stream)); - for(int i = 0; i < batch_size; ++i) - { - dispatch(d_temp_storage, temp_storage_size_bytes); - } - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds{}; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } + state.run([&] { dispatch(d_temp_storage.get(), temp_storage_size_bytes); }); - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(DataType)); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_input)); - if(is_tuning) - { - HIP_CHECK(hipFree(d_flags_2)); - HIP_CHECK(hipFree(d_flags_1)); - } - HIP_CHECK(hipFree(d_flags_0)); - HIP_CHECK(hipFree(d_output)); - HIP_CHECK(hipFree(d_selected_count_output)); - HIP_CHECK(hipFree(d_temp_storage)); + state.set_throughput(size, sizeof(DataType)); } static constexpr bool is_tuning = Probability == select_probability::tuning; @@ -239,7 +185,7 @@ struct device_select_flag_benchmark : public config_autotune_interface template -struct device_select_predicate_benchmark : public config_autotune_interface +struct device_select_predicate_benchmark : public benchmark_utils::autotune_interface { std::string name() const override { @@ -250,11 +196,12 @@ struct device_select_predicate_benchmark : public config_autotune_interface + ",cfg:" + partition_config_name() + "}"); } - void run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + // Calculate the number of elements size_t size = bytes / sizeof(DataType); @@ -264,15 +211,11 @@ struct device_select_predicate_benchmark : public config_autotune_interface static_cast(126), seed.get_0()); - DataType* d_input; - HIP_CHECK(hipMalloc(&d_input, size * sizeof(*d_input))); - HIP_CHECK(hipMemcpy(d_input, input.data(), size * sizeof(*d_input), hipMemcpyHostToDevice)); + common::device_ptr d_input(input); - DataType* d_output; - HIP_CHECK(hipMalloc(&d_output, size * sizeof(*d_output))); + common::device_ptr d_output(size); - unsigned int* d_selected_count_output; - HIP_CHECK(hipMalloc(&d_selected_count_output, sizeof(*d_selected_count_output))); + common::device_ptr d_selected_count_output(1); const auto dispatch = [&](void* d_temp_storage, size_t& temp_storage_size_bytes) { @@ -282,9 +225,9 @@ struct device_select_predicate_benchmark : public config_autotune_interface { return value < static_cast(127 * probability); }; HIP_CHECK(rocprim::select(d_temp_storage, temp_storage_size_bytes, - d_input, - d_output, - d_selected_count_output, + d_input.get(), + d_output.get(), + d_selected_count_output.get(), size, predicate, stream)); @@ -304,44 +247,11 @@ struct device_select_predicate_benchmark : public config_autotune_interface size_t temp_storage_size_bytes{}; dispatch(nullptr, temp_storage_size_bytes); - void* d_temp_storage{}; - HIP_CHECK(hipMalloc(&d_temp_storage, temp_storage_size_bytes)); - - for(int i = 0; i < warmup_iter; ++i) - { - dispatch(d_temp_storage, temp_storage_size_bytes); - } - HIP_CHECK(hipDeviceSynchronize()); - - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - HIP_CHECK(hipEventRecord(start, stream)); - for(int i = 0; i < batch_size; ++i) - { - dispatch(d_temp_storage, temp_storage_size_bytes); - } - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); + common::device_ptr d_temp_storage(temp_storage_size_bytes); - float elapsed_mseconds{}; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(DataType)); - state.SetItemsProcessed(state.iterations() * batch_size * size); + state.run([&] { dispatch(d_temp_storage.get(), temp_storage_size_bytes); }); - HIP_CHECK(hipFree(d_input)); - HIP_CHECK(hipFree(d_output)); - HIP_CHECK(hipFree(d_selected_count_output)); - HIP_CHECK(hipFree(d_temp_storage)); + state.set_throughput(size, sizeof(DataType)); } static constexpr bool is_tuning = Probability == select_probability::tuning; @@ -351,7 +261,7 @@ template -struct device_select_predicated_flag_benchmark : public config_autotune_interface +struct device_select_predicated_flag_benchmark : public benchmark_utils::autotune_interface { std::string name() const override { @@ -363,11 +273,12 @@ struct device_select_predicated_flag_benchmark : public config_autotune_interfac + get_probability_name(Probability) + ",cfg:" + partition_config_name() + "}"); } - void run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + // Calculate the number of elements size_t size = bytes / sizeof(DataType); @@ -392,35 +303,20 @@ struct device_select_predicated_flag_benchmark : public config_autotune_interfac flags_0 = get_random_data01(size, get_probability(Probability), seed.get_1()); } - DataType* d_input{}; - HIP_CHECK(hipMalloc(&d_input, size * sizeof(*d_input))); - HIP_CHECK(hipMemcpy(d_input, input.data(), size * sizeof(*d_input), hipMemcpyHostToDevice)); + common::device_ptr d_input(input); - FlagType* d_flags_0{}; - FlagType* d_flags_1{}; - FlagType* d_flags_2{}; - HIP_CHECK(hipMalloc(&d_flags_0, size * sizeof(*d_flags_0))); - HIP_CHECK( - hipMemcpy(d_flags_0, flags_0.data(), size * sizeof(*d_flags_0), hipMemcpyHostToDevice)); + common::device_ptr d_flags_0(flags_0); + common::device_ptr d_flags_1; + common::device_ptr d_flags_2; if(is_tuning) { - HIP_CHECK(hipMalloc(&d_flags_1, size * sizeof(*d_flags_1))); - HIP_CHECK(hipMemcpy(d_flags_1, - flags_1.data(), - size * sizeof(*d_flags_1), - hipMemcpyHostToDevice)); - HIP_CHECK(hipMalloc(&d_flags_2, size * sizeof(*d_flags_2))); - HIP_CHECK(hipMemcpy(d_flags_2, - flags_2.data(), - size * sizeof(*d_flags_2), - hipMemcpyHostToDevice)); + d_flags_1.store(flags_1); + d_flags_2.store(flags_2); } - DataType* d_output{}; - HIP_CHECK(hipMalloc(&d_output, size * sizeof(*d_output))); + common::device_ptr d_output(size); - unsigned int* d_selected_count_output{}; - HIP_CHECK(hipMalloc(&d_selected_count_output, sizeof(*d_selected_count_output))); + common::device_ptr d_selected_count_output(1); const auto dispatch = [&](void* d_temp_storage, size_t& temp_storage_size_bytes) { @@ -429,70 +325,31 @@ struct device_select_predicated_flag_benchmark : public config_autotune_interfac auto predicate = [](const FlagType& value) -> bool { return value; }; HIP_CHECK(rocprim::select(d_temp_storage, temp_storage_size_bytes, - d_input, + d_input.get(), d_flags, - d_output, - d_selected_count_output, + d_output.get(), + d_selected_count_output.get(), size, predicate, stream)); }; - dispatch_predicated_flags(d_flags_0); + dispatch_predicated_flags(d_flags_0.get()); if(is_tuning) { - dispatch_predicated_flags(d_flags_1); - dispatch_predicated_flags(d_flags_2); + dispatch_predicated_flags(d_flags_1.get()); + dispatch_predicated_flags(d_flags_2.get()); } }; // Allocate temporary storage memory size_t temp_storage_size_bytes{}; dispatch(nullptr, temp_storage_size_bytes); - void* d_temp_storage{}; - HIP_CHECK(hipMalloc(&d_temp_storage, temp_storage_size_bytes)); + common::device_ptr d_temp_storage(temp_storage_size_bytes); - for(int i = 0; i < warmup_iter; ++i) - { - dispatch(d_temp_storage, temp_storage_size_bytes); - } - HIP_CHECK(hipDeviceSynchronize()); - - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); + state.run([&] { dispatch(d_temp_storage.get(), temp_storage_size_bytes); }); - for(auto _ : state) - { - HIP_CHECK(hipEventRecord(start, stream)); - for(int i = 0; i < batch_size; ++i) - { - dispatch(d_temp_storage, temp_storage_size_bytes); - } - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds{}; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(DataType)); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_input)); - if(is_tuning) - { - HIP_CHECK(hipFree(d_flags_2)); - HIP_CHECK(hipFree(d_flags_1)); - } - HIP_CHECK(hipFree(d_flags_0)); - HIP_CHECK(hipFree(d_output)); - HIP_CHECK(hipFree(d_selected_count_output)); - HIP_CHECK(hipFree(d_temp_storage)); + state.set_throughput(size, sizeof(DataType)); } static constexpr bool is_tuning = Probability == select_probability::tuning; @@ -520,7 +377,7 @@ inline std::vector get_unique_input(size_t size, float probability, un template -struct device_select_unique_benchmark : public config_autotune_interface +struct device_select_unique_benchmark : public benchmark_utils::autotune_interface { std::string name() const override { @@ -531,11 +388,12 @@ struct device_select_unique_benchmark : public config_autotune_interface + ",cfg:" + partition_config_name() + "}"); } - void run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + // Calculate the number of elements size_t size = bytes / sizeof(DataType); @@ -554,31 +412,18 @@ struct device_select_unique_benchmark : public config_autotune_interface input_0 = get_unique_input(size, get_probability(Probability), seed.get_0()); } - DataType* d_input_0{}; - DataType* d_input_1{}; - DataType* d_input_2{}; - HIP_CHECK(hipMalloc(&d_input_0, size * sizeof(*d_input_0))); - HIP_CHECK( - hipMemcpy(d_input_0, input_0.data(), size * sizeof(*d_input_0), hipMemcpyHostToDevice)); + common::device_ptr d_input_0(input_0); + common::device_ptr d_input_1; + common::device_ptr d_input_2; if(is_tuning) { - HIP_CHECK(hipMalloc(&d_input_1, size * sizeof(*d_input_1))); - HIP_CHECK(hipMemcpy(d_input_1, - input_1.data(), - size * sizeof(*d_input_1), - hipMemcpyHostToDevice)); - HIP_CHECK(hipMalloc(&d_input_2, size * sizeof(*d_input_2))); - HIP_CHECK(hipMemcpy(d_input_2, - input_2.data(), - size * sizeof(*d_input_2), - hipMemcpyHostToDevice)); + d_input_1.store(input_1); + d_input_2.store(input_2); } - DataType* d_output{}; - HIP_CHECK(hipMalloc(&d_output, size * sizeof(*d_output))); + common::device_ptr d_output(size); - unsigned int* d_selected_count_output{}; - HIP_CHECK(hipMalloc(&d_selected_count_output, sizeof(*d_selected_count_output))); + common::device_ptr d_selected_count_output(1); const auto dispatch = [&](void* d_temp_storage, size_t& temp_storage_size_bytes) { @@ -587,67 +432,29 @@ struct device_select_unique_benchmark : public config_autotune_interface HIP_CHECK(rocprim::unique(d_temp_storage, temp_storage_size_bytes, d_input, - d_output, - d_selected_count_output, + d_output.get(), + d_selected_count_output.get(), size, rocprim::equal_to(), stream)); }; - dispatch_flags(d_input_0); + dispatch_flags(d_input_0.get()); if(is_tuning) { - dispatch_flags(d_input_1); - dispatch_flags(d_input_2); + dispatch_flags(d_input_1.get()); + dispatch_flags(d_input_2.get()); } }; // Allocate temporary storage memory size_t temp_storage_size_bytes{}; dispatch(nullptr, temp_storage_size_bytes); - void* d_temp_storage{}; - HIP_CHECK(hipMalloc(&d_temp_storage, temp_storage_size_bytes)); - - for(int i = 0; i < warmup_iter; ++i) - { - dispatch(d_temp_storage, temp_storage_size_bytes); - } - HIP_CHECK(hipDeviceSynchronize()); - - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - HIP_CHECK(hipEventRecord(start, stream)); - for(int i = 0; i < batch_size; ++i) - { - dispatch(d_temp_storage, temp_storage_size_bytes); - } - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds{}; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } + common::device_ptr d_temp_storage(temp_storage_size_bytes); - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); + state.run([&] { dispatch(d_temp_storage.get(), temp_storage_size_bytes); }); - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(DataType)); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - if(is_tuning) - { - HIP_CHECK(hipFree(d_input_2)); - HIP_CHECK(hipFree(d_input_1)); - } - HIP_CHECK(hipFree(d_input_0)); - HIP_CHECK(hipFree(d_output)); - HIP_CHECK(hipFree(d_selected_count_output)); - HIP_CHECK(hipFree(d_temp_storage)); + state.set_throughput(size, sizeof(DataType)); } static constexpr bool is_tuning = Probability == select_probability::tuning; @@ -657,7 +464,7 @@ template -struct device_select_unique_by_key_benchmark : public config_autotune_interface +struct device_select_unique_by_key_benchmark : public benchmark_utils::autotune_interface { std::string name() const override { @@ -669,11 +476,12 @@ struct device_select_unique_by_key_benchmark : public config_autotune_interface + ",cfg:" + partition_config_name() + "}"); } - void run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + // Calculate the number of elements size_t size = bytes / sizeof(KeyType); @@ -700,43 +508,22 @@ struct device_select_unique_by_key_benchmark : public config_autotune_interface random_range.second, seed.get_1()); - KeyType* d_keys_input_0{}; - KeyType* d_keys_input_1{}; - KeyType* d_keys_input_2{}; - HIP_CHECK(hipMalloc(&d_keys_input_0, size * sizeof(*d_keys_input_0))); - HIP_CHECK(hipMemcpy(d_keys_input_0, - input_keys_0.data(), - size * sizeof(*d_keys_input_0), - hipMemcpyHostToDevice)); + common::device_ptr d_keys_input_0(input_keys_0); + common::device_ptr d_keys_input_1; + common::device_ptr d_keys_input_2; if(is_tuning) { - HIP_CHECK(hipMalloc(&d_keys_input_1, size * sizeof(*d_keys_input_1))); - HIP_CHECK(hipMemcpy(d_keys_input_1, - input_keys_1.data(), - size * sizeof(*d_keys_input_1), - hipMemcpyHostToDevice)); - HIP_CHECK(hipMalloc(&d_keys_input_2, size * sizeof(*d_keys_input_2))); - HIP_CHECK(hipMemcpy(d_keys_input_2, - input_keys_2.data(), - size * sizeof(*d_keys_input_2), - hipMemcpyHostToDevice)); + d_keys_input_1.store(input_keys_1); + d_keys_input_2.store(input_keys_2); } - ValueType* d_values_input{}; - HIP_CHECK(hipMalloc(&d_values_input, size * sizeof(*d_values_input))); - HIP_CHECK(hipMemcpy(d_values_input, - input_values.data(), - size * sizeof(*d_values_input), - hipMemcpyHostToDevice)); + common::device_ptr d_values_input(input_values); - KeyType* d_keys_output{}; - HIP_CHECK(hipMalloc(&d_keys_output, size * sizeof(*d_keys_output))); + common::device_ptr d_keys_output(size); - ValueType* d_values_output{}; - HIP_CHECK(hipMalloc(&d_values_output, size * sizeof(*d_values_output))); + common::device_ptr d_values_output(size); - unsigned int* d_selected_count_output{}; - HIP_CHECK(hipMalloc(&d_selected_count_output, sizeof(*d_selected_count_output))); + common::device_ptr d_selected_count_output(1); const auto dispatch = [&](void* d_temp_storage, size_t& temp_storage_size_bytes) { @@ -745,72 +532,31 @@ struct device_select_unique_by_key_benchmark : public config_autotune_interface HIP_CHECK(rocprim::unique_by_key(d_temp_storage, temp_storage_size_bytes, d_keys_input, - d_values_input, - d_keys_output, - d_values_output, - d_selected_count_output, + d_values_input.get(), + d_keys_output.get(), + d_values_output.get(), + d_selected_count_output.get(), size, rocprim::equal_to(), stream)); }; - dispatch_flags(d_keys_input_0); + dispatch_flags(d_keys_input_0.get()); if(is_tuning) { - dispatch_flags(d_keys_input_1); - dispatch_flags(d_keys_input_2); + dispatch_flags(d_keys_input_1.get()); + dispatch_flags(d_keys_input_2.get()); } }; // Allocate temporary storage memory size_t temp_storage_size_bytes{}; dispatch(nullptr, temp_storage_size_bytes); - void* d_temp_storage{}; - HIP_CHECK(hipMalloc(&d_temp_storage, temp_storage_size_bytes)); + common::device_ptr d_temp_storage(temp_storage_size_bytes); - for(int i = 0; i < warmup_iter; ++i) - { - dispatch(d_temp_storage, temp_storage_size_bytes); - } - HIP_CHECK(hipDeviceSynchronize()); - - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); + state.run([&] { dispatch(d_temp_storage.get(), temp_storage_size_bytes); }); - for(auto _ : state) - { - HIP_CHECK(hipEventRecord(start, stream)); - for(int i = 0; i < batch_size; ++i) - { - dispatch(d_temp_storage, temp_storage_size_bytes); - } - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds{}; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size - * (sizeof(KeyType) + sizeof(ValueType))); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - if(is_tuning) - { - HIP_CHECK(hipFree(d_keys_input_2)); - HIP_CHECK(hipFree(d_keys_input_1)); - } - HIP_CHECK(hipFree(d_keys_input_0)); - HIP_CHECK(hipFree(d_values_input)); - HIP_CHECK(hipFree(d_keys_output)); - HIP_CHECK(hipFree(d_values_output)); - HIP_CHECK(hipFree(d_selected_count_output)); - HIP_CHECK(hipFree(d_temp_storage)); + state.set_throughput(size, sizeof(KeyType) + sizeof(ValueType)); } static constexpr bool is_tuning = Probability == select_probability::tuning; @@ -828,7 +574,7 @@ struct create_benchmark static constexpr unsigned int max_items_per_thread = max_shared_memory / (block_size * max_size_per_element); - void operator()(std::vector>& storage) + void operator()(std::vector>& storage) { storage.emplace_back( std::make_unique>()); @@ -845,7 +591,7 @@ struct create_benchmark template struct create_benchmark { - void operator()(std::vector>& storage) + void operator()(std::vector>& storage) { storage.emplace_back(std::make_unique>()); storage.emplace_back( @@ -860,14 +606,14 @@ struct device_select_benchmark_generator template struct create_ipt { - void operator()(std::vector>& storage) + void operator()(std::vector>& storage) { using config = rocprim::select_config; create_benchmark{}(storage); } }; - static void create(std::vector>& storage) + static void create(std::vector>& storage) { static constexpr int max_items_per_thread = std::min(64 / std::max(sizeof(KeyType), sizeof(ValueType)), size_t{32}); diff --git a/benchmark/benchmark_device_transform.cpp b/benchmark/benchmark_device_transform.cpp index e4362c83e..26095987b 100644 --- a/benchmark/benchmark_device_transform.cpp +++ b/benchmark/benchmark_device_transform.cpp @@ -23,16 +23,10 @@ #include "benchmark_device_transform.parallel.hpp" #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" - #ifndef BENCHMARK_CONFIG_TUNING #include "../common/utils_custom_type.hpp" #endif -// Google Benchmark -#include - // HIP API #include @@ -48,67 +42,16 @@ #include #endif -#ifndef DEFAULT_BYTES -const size_t DEFAULT_BYTES = 1024 * 1024 * 128 * 4; -#endif +#define CREATE_BENCHMARK(T) executor.queue_instance(device_transform_benchmark()); -#define CREATE_BENCHMARK(T) \ - { \ - const device_transform_benchmark instance{}; \ - REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ - } +#define CREATE_BENCHMARK_BINARY(T) \ + executor.queue_instance(device_transform_benchmark()); int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); -#ifdef BENCHMARK_CONFIG_TUNING - // optionally run an evenly split subset of benchmarks, when making multiple program invocations - parser.set_optional("parallel_instance", - "parallel_instance", - 0, - "parallel instance index"); - parser.set_optional("parallel_instances", - "parallel_instances", - 1, - "total parallel instances"); -#endif // BENCHMARK_CONFIG_TUNING - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); - - // HIP - hipStream_t stream = 0; // default + benchmark_utils::executor executor(argc, argv, 512 * benchmark_utils::MiB, 10, 5); - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks = {}; -#ifdef BENCHMARK_CONFIG_TUNING - const int parallel_instance = parser.get("parallel_instance"); - const int parallel_instances = parser.get("parallel_instances"); - config_autotune_register::register_benchmark_subset(benchmarks, - parallel_instance, - parallel_instances, - bytes, - seed, - stream); -#else // BENCHMARK_CONFIG_TUNING +#ifndef BENCHMARK_CONFIG_TUNING using custom_float2 = common::custom_type; using custom_double2 = common::custom_type; CREATE_BENCHMARK(int) @@ -126,26 +69,13 @@ int main(int argc, char* argv[]) CREATE_BENCHMARK(rocprim::int128_t) CREATE_BENCHMARK(rocprim::uint128_t) -#endif // BENCHMARK_CONFIG_TUNING - - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } - - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); + CREATE_BENCHMARK_BINARY(int) + CREATE_BENCHMARK_BINARY(float) + CREATE_BENCHMARK_BINARY(int8_t) + CREATE_BENCHMARK_BINARY(rocprim::int128_t) + CREATE_BENCHMARK_BINARY(custom_double2) +#endif - return 0; + executor.run(); } diff --git a/benchmark/benchmark_device_transform.parallel.cpp.in b/benchmark/benchmark_device_transform.parallel.cpp.in index f3a6a135c..2ca96b9a1 100644 --- a/benchmark/benchmark_device_transform.parallel.cpp.in +++ b/benchmark/benchmark_device_transform.parallel.cpp.in @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -28,8 +28,10 @@ #include namespace { - auto benchmarks = config_autotune_register::create_bulk( + auto unused = benchmark_utils::executor::queue_autotune( device_transform_benchmark_generator< - @DataType@, - @BlockSize@>::create); + @DataType@, + false, + @BlockSize@, + rocprim::load_default>::create); } diff --git a/benchmark/benchmark_device_transform.parallel.hpp b/benchmark/benchmark_device_transform.parallel.hpp index 0db664568..7a0523e08 100644 --- a/benchmark/benchmark_device_transform.parallel.hpp +++ b/benchmark/benchmark_device_transform.parallel.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -25,6 +25,8 @@ #include "benchmark_utils.hpp" +#include "../common/utils_device_ptr.hpp" + // Google Benchmark #include @@ -47,7 +49,8 @@ std::string transform_config_name() { auto config = Config(); return "{bs:" + std::to_string(config.block_size) - + ",ipt:" + std::to_string(config.items_per_thread) + "}"; + + ",ipt:" + std::to_string(config.items_per_thread) + + ",lt:" + get_thread_load_method_name(config.load_type) + "}"; } template<> @@ -56,27 +59,29 @@ inline std::string transform_config_name() return "default_config"; } -template -struct device_transform_benchmark : public config_autotune_interface +template +struct device_transform_benchmark : public benchmark_utils::autotune_interface { std::string name() const override { using namespace std::string_literals; - return bench_naming::format_name("{lvl:device,algo:transform,value_type:" - + std::string(Traits::name()) - + ",cfg:" + transform_config_name() + "}"); + return bench_naming::format_name( + "{lvl:device,algo:transform" + std::string(IsPointer ? "_pointer" : "") + + ",op:" + std::string(IsBinary ? "binary" : "unary") + ",value_type:" + + std::string(Traits::name()) + ",cfg:" + transform_config_name() + "}"); } - static constexpr unsigned int batch_size = 10; - static constexpr unsigned int warmup_size = 5; - - void run(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) const override + void run(benchmark_utils::state&& state) override { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + using output_type = T; // Calculate the number of elements @@ -89,88 +94,68 @@ struct device_transform_benchmark : public config_autotune_interface const std::vector input = get_random_data(size, random_range.first, random_range.second, seed.get_0()); - T* d_input; - output_type* d_output = nullptr; - HIP_CHECK(hipMalloc(&d_input, input.size() * sizeof(input[0]))); - HIP_CHECK(hipMemcpy(d_input, - input.data(), - input.size() * sizeof(input[0]), - hipMemcpyHostToDevice)); - - HIP_CHECK(hipMalloc(&d_output, size * sizeof(output_type))); + common::device_ptr d_input(input); + common::device_ptr d_output(size); - const auto launch = [&] + if constexpr(IsBinary) { - auto transform_op = [](T v) { return v + T(5); }; - return rocprim::transform(d_input, - d_output, - size, - transform_op, - stream, - debug_synchronous); - }; - - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - HIP_CHECK(launch()); - } - HIP_CHECK(hipDeviceSynchronize()); + const std::vector input2 + = get_random_data(size, random_range.first, random_range.second, seed.get_0()); + common::device_ptr d_input2(input2); - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - // Run - for(auto _ : state) + // If it is not a unary operator, it can not make use of the pointer optimization. + const auto launch = [&] + { + auto transform_op = [](T v1, T v2) { return v1 + v2; }; + return rocprim::transform(rocprim::tuple(d_input.get(), d_input2.get()), + d_output.get(), + size, + transform_op, + stream, + debug_synchronous); + }; + + state.run([&] { HIP_CHECK(launch()); }); + state.set_throughput(size, sizeof(T) + sizeof(T)); + } + else { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) + const auto launch = [&] { - HIP_CHECK(launch()); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); + auto transform_op = [](T v) { return v + T(5); }; + return rocprim::detail::transform_impl(d_input.get(), + d_output.get(), + size, + transform_op, + stream, + debug_synchronous); + }; + + state.run([&] { HIP_CHECK(launch()); }); + state.set_throughput(size, sizeof(T)); } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(T)); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_input)); - HIP_CHECK(hipFree(d_output)); } }; -template +template struct device_transform_benchmark_generator { template struct create_ipt { - using generated_config = rocprim::transform_config; + using generated_config = rocprim:: + transform_config; - void operator()(std::vector>& storage) + void operator()(std::vector>& storage) { storage.emplace_back( - std::make_unique>()); + std::make_unique< + device_transform_benchmark>()); } }; - static void create(std::vector>& storage) + static void create(std::vector>& storage) { static constexpr unsigned int min_items_per_thread = 0; static constexpr unsigned int max_items_per_thread = rocprim::Log2<16>::VALUE; diff --git a/benchmark/benchmark_device_transform_pointer.cpp b/benchmark/benchmark_device_transform_pointer.cpp new file mode 100644 index 000000000..804ea630e --- /dev/null +++ b/benchmark/benchmark_device_transform_pointer.cpp @@ -0,0 +1,72 @@ +// MIT License +// +// Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include "benchmark_device_transform.parallel.hpp" +#include "benchmark_utils.hpp" + +#ifndef BENCHMARK_CONFIG_TUNING + #include "../common/utils_custom_type.hpp" +#endif + +// HIP API +#include + +// rocPRIM +#ifndef BENCHMARK_CONFIG_TUNING + #include +#endif + +#include +#include +#include +#ifndef BENCHMARK_CONFIG_TUNING + #include +#endif + +#define CREATE_BENCHMARK(T) executor.queue_instance(device_transform_benchmark()); + +int main(int argc, char* argv[]) +{ + benchmark_utils::executor executor(argc, argv, 512 * benchmark_utils::MiB, 10, 5); + +#ifndef BENCHMARK_CONFIG_TUNING + using custom_float2 = common::custom_type; + using custom_double2 = common::custom_type; + CREATE_BENCHMARK(int) + CREATE_BENCHMARK(long long) + + CREATE_BENCHMARK(int8_t) + CREATE_BENCHMARK(uint8_t) + CREATE_BENCHMARK(rocprim::half) + + CREATE_BENCHMARK(float) + CREATE_BENCHMARK(double) + + CREATE_BENCHMARK(custom_float2) + CREATE_BENCHMARK(custom_double2) + + CREATE_BENCHMARK(rocprim::int128_t) + CREATE_BENCHMARK(rocprim::uint128_t) +#endif + + executor.run(); +} diff --git a/rocprim/include/rocprim/detail/radix_sort.hpp b/benchmark/benchmark_device_transform_pointer.parallel.cpp.in similarity index 59% rename from rocprim/include/rocprim/detail/radix_sort.hpp rename to benchmark/benchmark_device_transform_pointer.parallel.cpp.in index 2be7221b8..875d90db2 100644 --- a/rocprim/include/rocprim/detail/radix_sort.hpp +++ b/benchmark/benchmark_device_transform_pointer.parallel.cpp.in @@ -1,4 +1,6 @@ -// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. +// MIT License +// +// Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -7,22 +9,29 @@ // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include "benchmark_utils.hpp" +#include "benchmark_device_transform.parallel.hpp" -#ifndef ROCPRIM_DETAIL_RADIX_SORT_HPP_ -#define ROCPRIM_DETAIL_RADIX_SORT_HPP_ +#include -ROCPRIM_PRAGMA_MESSAGE("Functionality from rocprim/detail/radix_sort.hpp has been moved to " - "rocprim/thread/radix_key_codec.hpp.") -#include "../thread/radix_key_codec.hpp" +#include -#endif // ROCPRIM_DETAIL_RADIX_SORT_HPP_ +namespace { + auto unused = benchmark_utils::executor::queue_autotune( + device_transform_benchmark_generator< + @DataType@, + true, + @BlockSize@, + @LoadType@>::create); +} diff --git a/benchmark/benchmark_predicate_iterator.cpp b/benchmark/benchmark_predicate_iterator.cpp index 44a239e37..ec88cc6dd 100644 --- a/benchmark/benchmark_predicate_iterator.cpp +++ b/benchmark/benchmark_predicate_iterator.cpp @@ -21,10 +21,10 @@ // SOFTWARE. #include "benchmark_utils.hpp" -#include "cmdparser.hpp" #include "../common/predicate_iterator.hpp" #include "../common/utils_custom_type.hpp" +#include "../common/utils_device_ptr.hpp" #include @@ -42,13 +42,6 @@ #include #include -#ifndef DEFAULT_BYTES -const size_t DEFAULT_BYTES = 1024 * 1024 * 128 * 4; -#endif - -const unsigned int batch_size = 10; -const unsigned int warmup_size = 5; - template struct less_than { @@ -109,11 +102,12 @@ struct write_predicate_it }; template -void run_benchmark(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) +void run_benchmark(benchmark_utils::state&& state) { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + using T = typename IteratorBenchmark::value_type; // Calculate the number of elements @@ -122,141 +116,53 @@ void run_benchmark(benchmark::State& state, const auto random_range = limit_random_range(0, 99); std::vector input = get_random_data(size, random_range.first, random_range.second, seed.get_0()); - T* d_input; - T* d_output; - HIP_CHECK(hipMalloc(reinterpret_cast(&d_input), size * sizeof(T))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_output), size * sizeof(T))); - HIP_CHECK(hipMemcpy(d_input, input.data(), size * sizeof(T), hipMemcpyHostToDevice)); - HIP_CHECK(hipDeviceSynchronize()); - - // Warm-up - for(size_t i = 0; i < warmup_size; ++i) - { - IteratorBenchmark{}(d_input, d_output, size, stream); - } + common::device_ptr d_input(input); + common::device_ptr d_output(size); HIP_CHECK(hipDeviceSynchronize()); - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; ++i) - { - IteratorBenchmark{}(d_input, d_output, size, stream); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } + state.run([&] { IteratorBenchmark{}(d_input.get(), d_output.get(), size, stream); }); - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(T)); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_input)); - HIP_CHECK(hipFree(d_output)); + state.set_throughput(size, sizeof(T)); } -#define CREATE_BENCHMARK(B, T, C) \ - benchmark::RegisterBenchmark(bench_naming::format_name("{lvl:device,algo:" #B ",p:p" #C \ - ",key_type:" #T ",cfg:default_config}") \ - .c_str(), \ - run_benchmark, common::increment_by<5>>>, \ - bytes, \ - seed, \ - stream) +#define CREATE_BENCHMARK(B, T, C) \ + executor.queue_fn(bench_naming::format_name("{lvl:device,algo:" #B ",p:p" #C ",key_type:" #T \ + ",cfg:default_config}") \ + .c_str(), \ + run_benchmark, common::increment_by<5>>>); // clang-format off #define CREATE_TYPED_BENCHMARK(T) \ - CREATE_BENCHMARK(transform_it, T, 0), \ - CREATE_BENCHMARK(read_predicate_it, T, 0), \ - CREATE_BENCHMARK(write_predicate_it, T, 0), \ - CREATE_BENCHMARK(transform_it, T, 25), \ - CREATE_BENCHMARK(read_predicate_it, T, 25), \ - CREATE_BENCHMARK(write_predicate_it, T, 25), \ - CREATE_BENCHMARK(transform_it, T, 50), \ - CREATE_BENCHMARK(read_predicate_it, T, 50), \ - CREATE_BENCHMARK(write_predicate_it, T, 50), \ - CREATE_BENCHMARK(transform_it, T, 75), \ - CREATE_BENCHMARK(read_predicate_it, T, 75), \ - CREATE_BENCHMARK(write_predicate_it, T, 75), \ - CREATE_BENCHMARK(transform_it, T, 100), \ - CREATE_BENCHMARK(read_predicate_it, T, 100), \ + CREATE_BENCHMARK(transform_it, T, 0) \ + CREATE_BENCHMARK(read_predicate_it, T, 0) \ + CREATE_BENCHMARK(write_predicate_it, T, 0) \ + CREATE_BENCHMARK(transform_it, T, 25) \ + CREATE_BENCHMARK(read_predicate_it, T, 25) \ + CREATE_BENCHMARK(write_predicate_it, T, 25) \ + CREATE_BENCHMARK(transform_it, T, 50) \ + CREATE_BENCHMARK(read_predicate_it, T, 50) \ + CREATE_BENCHMARK(write_predicate_it, T, 50) \ + CREATE_BENCHMARK(transform_it, T, 75) \ + CREATE_BENCHMARK(read_predicate_it, T, 75) \ + CREATE_BENCHMARK(write_predicate_it, T, 75) \ + CREATE_BENCHMARK(transform_it, T, 100) \ + CREATE_BENCHMARK(read_predicate_it, T, 100) \ CREATE_BENCHMARK(write_predicate_it, T, 100) // clang-format on int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); - - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); + benchmark_utils::executor executor(argc, argv, 512 * benchmark_utils::MiB, 10, 5); using custom_128 = common::custom_type; - // Add benchmarks - std::vector benchmarks - = {CREATE_TYPED_BENCHMARK(int8_t), - CREATE_TYPED_BENCHMARK(int16_t), - CREATE_TYPED_BENCHMARK(int32_t), - CREATE_TYPED_BENCHMARK(int64_t), - CREATE_TYPED_BENCHMARK(custom_128), - CREATE_TYPED_BENCHMARK(rocprim::int128_t), - CREATE_TYPED_BENCHMARK(rocprim::uint128_t)}; - - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } - - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); + CREATE_TYPED_BENCHMARK(int8_t) + CREATE_TYPED_BENCHMARK(int16_t) + CREATE_TYPED_BENCHMARK(int32_t) + CREATE_TYPED_BENCHMARK(int64_t) + CREATE_TYPED_BENCHMARK(custom_128) + CREATE_TYPED_BENCHMARK(rocprim::int128_t) + CREATE_TYPED_BENCHMARK(rocprim::uint128_t) - return 0; + executor.run(); } diff --git a/benchmark/benchmark_utils.hpp b/benchmark/benchmark_utils.hpp index 88551af06..84076fe41 100644 --- a/benchmark/benchmark_utils.hpp +++ b/benchmark/benchmark_utils.hpp @@ -37,10 +37,12 @@ #include #include #include -#include #include #include +// CmdParser +#include "cmdparser.hpp" + #include #include #include @@ -48,6 +50,7 @@ #include #include #include +#include #include #include #include @@ -84,6 +87,8 @@ class managed_seed } } + managed_seed() {} + unsigned int get_0() const { return is_random ? std::random_device{}() : seeds[0]; @@ -548,87 +553,6 @@ void static_for_each(Args&&... args) std::forward(args)...); } -#define REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance) \ - benchmark::internal::Benchmark* benchmark = benchmark::RegisterBenchmark( \ - instance.name().c_str(), \ - [instance](benchmark::State& state, \ - size_t _size, \ - const managed_seed& _seed, \ - hipStream_t _stream) { instance.run(state, _size, _seed, _stream); }, \ - size, \ - seed, \ - stream); \ - benchmarks.emplace_back(benchmark) - -struct config_autotune_interface -{ - virtual std::string name() const = 0; - virtual std::string sort_key() const - { - return name(); - }; - virtual ~config_autotune_interface() = default; - virtual void run(benchmark::State&, size_t, const managed_seed&, hipStream_t) const = 0; -}; - -struct config_autotune_register -{ - static std::vector>& vector() - { - static std::vector> storage; - return storage; - } - - template - static config_autotune_register create() - { - vector().push_back(std::make_unique()); - return config_autotune_register(); - } - - template - static config_autotune_register create_bulk(BulkCreateFunction&& f) - { - std::forward(f)(vector()); - return config_autotune_register(); - } - - // Register a subset of all created benchmarks for the current parallel instance and add to vector. - static void register_benchmark_subset(std::vector& benchmarks, - int parallel_instance_index, - int parallel_instance_count, - size_t size, - const managed_seed& seed, - hipStream_t stream) - { - std::vector>& configs = vector(); - // sorting to get a consistent order because order of initialization of static variables is undefined by the C++ standard. - std::sort(configs.begin(), - configs.end(), - [](const auto& l, const auto& r) { return l->sort_key() < r->sort_key(); }); - size_t configs_per_instance - = (configs.size() + parallel_instance_count - 1) / parallel_instance_count; - size_t start = std::min(parallel_instance_index * configs_per_instance, configs.size()); - size_t end = std::min((parallel_instance_index + 1) * configs_per_instance, configs.size()); - for(size_t i = start; i < end; ++i) - { - std::unique_ptr& uniq_ptr = configs.at(i); - config_autotune_interface* tuning_benchmark = uniq_ptr.get(); - benchmark::internal::Benchmark* benchmark = benchmark::RegisterBenchmark( - tuning_benchmark->name().c_str(), - [tuning_benchmark](benchmark::State& state, - size_t size, - const managed_seed& seed, - hipStream_t stream) - { tuning_benchmark->run(state, size, seed, stream); }, - size, - seed, - stream); - benchmarks.emplace_back(benchmark); - } - } -}; - // Inserts spaces at beginning of string if string shorter than specified length. inline std::string pad_string(std::string str, const size_t len) { @@ -994,107 +918,6 @@ inline const char* Traits::name() return "rocprim::uint128_t"; } -inline void add_common_benchmark_info() -{ - hipDeviceProp_t devProp; - int device_id = 0; - HIP_CHECK(hipGetDevice(&device_id)); - HIP_CHECK(hipGetDeviceProperties(&devProp, device_id)); - - auto str = [](const std::string& name, const std::string& val) - { benchmark::AddCustomContext(name, val); }; - - auto num = [](const std::string& name, const auto& value) - { benchmark::AddCustomContext(name, std::to_string(value)); }; - - auto dim2 = [num](const std::string& name, const auto* values) - { - num(name + "_x", values[0]); - num(name + "_y", values[1]); - }; - - auto dim3 = [num, dim2](const std::string& name, const auto* values) - { - dim2(name, values); - num(name + "_z", values[2]); - }; - - str("hdp_name", devProp.name); - num("hdp_total_global_mem", devProp.totalGlobalMem); - num("hdp_shared_mem_per_block", devProp.sharedMemPerBlock); - num("hdp_regs_per_block", devProp.regsPerBlock); - num("hdp_warp_size", devProp.warpSize); - num("hdp_max_threads_per_block", devProp.maxThreadsPerBlock); - dim3("hdp_max_threads_dim", devProp.maxThreadsDim); - dim3("hdp_max_grid_size", devProp.maxGridSize); - num("hdp_clock_rate", devProp.clockRate); - num("hdp_memory_clock_rate", devProp.memoryClockRate); - num("hdp_memory_bus_width", devProp.memoryBusWidth); - num("hdp_total_const_mem", devProp.totalConstMem); - num("hdp_major", devProp.major); - num("hdp_minor", devProp.minor); - num("hdp_multi_processor_count", devProp.multiProcessorCount); - num("hdp_l2_cache_size", devProp.l2CacheSize); - num("hdp_max_threads_per_multiprocessor", devProp.maxThreadsPerMultiProcessor); - num("hdp_compute_mode", devProp.computeMode); - num("hdp_clock_instruction_rate", devProp.clockInstructionRate); - num("hdp_concurrent_kernels", devProp.concurrentKernels); - num("hdp_pci_domain_id", devProp.pciDomainID); - num("hdp_pci_bus_id", devProp.pciBusID); - num("hdp_pci_device_id", devProp.pciDeviceID); - num("hdp_max_shared_memory_per_multi_processor", devProp.maxSharedMemoryPerMultiProcessor); - num("hdp_is_multi_gpu_board", devProp.isMultiGpuBoard); - num("hdp_can_map_host_memory", devProp.canMapHostMemory); - str("hdp_gcn_arch_name", devProp.gcnArchName); - num("hdp_integrated", devProp.integrated); - num("hdp_cooperative_launch", devProp.cooperativeLaunch); - num("hdp_cooperative_multi_device_launch", devProp.cooperativeMultiDeviceLaunch); - num("hdp_max_texture_1d_linear", devProp.maxTexture1DLinear); - num("hdp_max_texture_1d", devProp.maxTexture1D); - dim2("hdp_max_texture_2d", devProp.maxTexture2D); - dim3("hdp_max_texture_3d", devProp.maxTexture3D); - num("hdp_mem_pitch", devProp.memPitch); - num("hdp_texture_alignment", devProp.textureAlignment); - num("hdp_texture_pitch_alignment", devProp.texturePitchAlignment); - num("hdp_kernel_exec_timeout_enabled", devProp.kernelExecTimeoutEnabled); - num("hdp_ecc_enabled", devProp.ECCEnabled); - num("hdp_tcc_driver", devProp.tccDriver); - num("hdp_cooperative_multi_device_unmatched_func", devProp.cooperativeMultiDeviceUnmatchedFunc); - num("hdp_cooperative_multi_device_unmatched_grid_dim", - devProp.cooperativeMultiDeviceUnmatchedGridDim); - num("hdp_cooperative_multi_device_unmatched_block_dim", - devProp.cooperativeMultiDeviceUnmatchedBlockDim); - num("hdp_cooperative_multi_device_unmatched_shared_mem", - devProp.cooperativeMultiDeviceUnmatchedSharedMem); - num("hdp_is_large_bar", devProp.isLargeBar); - num("hdp_asic_revision", devProp.asicRevision); - num("hdp_managed_memory", devProp.managedMemory); - num("hdp_direct_managed_mem_access_from_host", devProp.directManagedMemAccessFromHost); - num("hdp_concurrent_managed_access", devProp.concurrentManagedAccess); - num("hdp_pageable_memory_access", devProp.pageableMemoryAccess); - num("hdp_pageable_memory_access_uses_host_page_tables", - devProp.pageableMemoryAccessUsesHostPageTables); - - const auto arch = devProp.arch; - num("hdp_arch_has_global_int32_atomics", arch.hasGlobalInt32Atomics); - num("hdp_arch_has_global_float_atomic_exch", arch.hasGlobalFloatAtomicExch); - num("hdp_arch_has_shared_int32_atomics", arch.hasSharedInt32Atomics); - num("hdp_arch_has_shared_float_atomic_exch", arch.hasSharedFloatAtomicExch); - num("hdp_arch_has_float_atomic_add", arch.hasFloatAtomicAdd); - num("hdp_arch_has_global_int64_atomics", arch.hasGlobalInt64Atomics); - num("hdp_arch_has_shared_int64_atomics", arch.hasSharedInt64Atomics); - num("hdp_arch_has_doubles", arch.hasDoubles); - num("hdp_arch_has_warp_vote", arch.hasWarpVote); - num("hdp_arch_has_warp_ballot", arch.hasWarpBallot); - num("hdp_arch_has_warp_shuffle", arch.hasWarpShuffle); - num("hdp_arch_has_funnel_shift", arch.hasFunnelShift); - num("hdp_arch_has_thread_fence_system", arch.hasThreadFenceSystem); - num("hdp_arch_has_sync_threads_ext", arch.hasSyncThreadsExt); - num("hdp_arch_has_surface_funcs", arch.hasSurfaceFuncs); - num("hdp_arch_has_3d_grid", arch.has3dGrid); - num("hdp_arch_has_dynamic_parallelism", arch.hasDynamicParallelism); -} - inline const char* get_block_scan_algorithm_name(rocprim::block_scan_algorithm alg) { switch(alg) @@ -1126,6 +949,22 @@ inline const char* get_block_load_method_name(rocprim::block_load_method method) return "default_method"; } +inline const char* get_thread_load_method_name(rocprim::cache_load_modifier method) +{ + switch(method) + { + case rocprim::load_default: return "load_default"; + case rocprim::load_ca: return "load_ca"; + case rocprim::load_cg: return "load_cg"; + case rocprim::load_nontemporal: return "load_nontemporal"; + case rocprim::load_cv: return "load_cv"; + case rocprim::load_ldg: return "load_ldg"; + case rocprim::load_volatile: return "load_volatile"; + case rocprim::load_count: return "load_count"; + } + return "load_default"; +} + template struct alignas(Alignment) custom_aligned_type { @@ -1146,4 +985,578 @@ inline std::string partition_config_name() return "default_config"; } +namespace benchmark_utils +{ + +constexpr size_t KiB = 1024; +constexpr size_t MiB = 1024 * KiB; +constexpr size_t GiB = 1024 * MiB; + +class state +{ +public: + state(hipStream_t stream, + size_t size, + const managed_seed& seed, + size_t batch_iterations, + benchmark::State& gbench_state, + size_t warmup_iterations, + bool cold, + bool record_as_whole) + : stream(stream) + , size(size) + , bytes(size) + , seed(seed) + , batch_iterations(batch_iterations) + , gbench_state(gbench_state) + , warmup_iterations(warmup_iterations) + , cold(cold) + , record_as_whole(record_as_whole) + , events(record_as_whole ? 2 : batch_iterations * 2) + {} + + // Used to reset the input array of algorithms like device_merge_inplace. + void run_before_every_iteration(std::function lambda) + { + run_before_every_iteration_lambda = lambda; + } + + // Used to accumulate the results of state.run() calls. + void accumulate_total_gbench_iterations_every_run() + { + reset_total_gbench_iterations_every_run = false; + } + + void run(std::function kernel) + { + for(auto& event : events) + { + HIP_CHECK(hipEventCreate(&event)); + } + + // Warm-up + for(size_t i = 0; i < warmup_iterations; ++i) + { + // Benchmarks may expect their kernel input to be prepared by this lambda, + // so to prevent any potential crashes, we call the lambda during warm-up. + if(run_before_every_iteration_lambda) + { + run_before_every_iteration_lambda(); + } + + kernel(); + } + HIP_CHECK(hipDeviceSynchronize()); + + if(run_before_every_iteration_lambda && batch_iterations > 1 && record_as_whole) + { + std::cerr << "Error: This benchmark calls run_before_every_iteration() and has a " + "batch_iterations count that is higher than 1, which means it does not " + "support using --record_as_whole.\n"; + exit(EXIT_FAILURE); + } + + // Run + for(auto _ : gbench_state) + { + if(record_as_whole) + { + if(run_before_every_iteration_lambda) + { + run_before_every_iteration_lambda(); + } + + HIP_CHECK(hipEventRecord(events[0], stream)); + for(size_t i = 0; i < batch_iterations; ++i) + { + kernel(); + } + HIP_CHECK(hipEventRecord(events[1], stream)); + HIP_CHECK(hipEventSynchronize(events[1])); + + float elapsed_mseconds; + HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, events[0], events[1])); + times.emplace_back(elapsed_mseconds); + gbench_state.SetIterationTime(elapsed_mseconds / 1000); + } + else + { + for(size_t i = 0; i < batch_iterations; ++i) + { + if(run_before_every_iteration_lambda) + { + run_before_every_iteration_lambda(); + } + + if(cold) + { + clear_gpu_cache(stream); + } + + // Even events record the start time. + HIP_CHECK(hipEventRecord(events[i * 2], stream)); + + kernel(); + + // Odd events record the stop time. + HIP_CHECK(hipEventRecord(events[i * 2 + 1], stream)); + } + + // Wait until the last record event has completed. + HIP_CHECK(hipEventSynchronize(events[batch_iterations * 2 - 1])); + + // Accumulate the total elapsed time. + double elapsed_mseconds = 0.0; + for(size_t i = 0; i < batch_iterations; i++) + { + float iteration_mseconds; + HIP_CHECK( + hipEventElapsedTime(&iteration_mseconds, events[i * 2], events[i * 2 + 1])); + times.emplace_back(iteration_mseconds); + elapsed_mseconds += iteration_mseconds; + } + gbench_state.SetIterationTime(elapsed_mseconds / 1000); + } + } + + if(reset_total_gbench_iterations_every_run) + { + total_gbench_iterations = 0; + } + total_gbench_iterations += gbench_state.iterations(); + + for(const auto& event : events) + { + HIP_CHECK(hipEventDestroy(event)); + } + } + + void set_throughput(size_t actual_size, size_t type_size) + { + if(has_set_throughput) + { + std::cerr << "Error: Benchmarks should only ever call set_throughput() once, at the " + "very end.\n"; + exit(EXIT_FAILURE); + } + has_set_throughput = true; + + gbench_state.SetBytesProcessed(total_gbench_iterations * batch_iterations * actual_size + * type_size); + gbench_state.SetItemsProcessed(total_gbench_iterations * batch_iterations * actual_size); + + output_statistics(); + } + + hipStream_t stream; + size_t size; + size_t bytes; + managed_seed seed; + size_t batch_iterations; + benchmark::State& gbench_state; + +private: + // Zeros a 256 MiB buffer, used to clear the cache before each kernel call. + // 256 MiB is the size of the largest cache on any AMD GPU. + // It is currently not possible to fetch the L3 cache size from the runtime. + inline void clear_gpu_cache(hipStream_t stream) + { + constexpr size_t buf_size = 256 * MiB; + static void* buf = nullptr; + if(!buf) + { + HIP_CHECK(hipMalloc(&buf, buf_size)); + } + HIP_CHECK(hipMemsetAsync(buf, 0, buf_size, stream)); + } + + void output_statistics() + { + double mean = get_mean(); + double median = get_median(); + double stddev = get_stddev(mean); + double cv = get_cv(stddev, mean); + + gbench_state.counters["mean"] = mean; + gbench_state.counters["median"] = median; + gbench_state.counters["stddev"] = stddev; + gbench_state.counters["cv"] = cv; + } + + double get_mean() + { + return std::reduce(times.begin(), times.end()) / times.size(); + } + + // Technically when times.size() is even, the median is the arithmetic mean + // of the elements k=N/2 and k=N/2+1. This would be overkill here, + // as times.size() is large enough, and recorded times are similar enough. + double get_median() + { + size_t center_index = times.size() / 2; + std::nth_element(times.begin(), times.begin() + center_index, times.end()); + return times[center_index]; + } + + double get_stddev(double mean) + { + auto SumSquares = [](const std::vector& v) + { return std::transform_reduce(v.begin(), v.end(), v.begin(), 0.0); }; + auto Sqr = [](double dat) { return dat * dat; }; + auto Sqrt = [](double dat) { return dat < 0.0 ? 0.0 : std::sqrt(dat); }; + + double stddev = 0.0; + if(times.size() > 1) + { + double avg_squares = SumSquares(times) * (1.0 / times.size()); + stddev = Sqrt(times.size() / (times.size() - 1.0) * (avg_squares - Sqr(mean))); + } + return stddev; + } + + double get_cv(double stddev, double mean) + { + return times.size() >= 2 ? stddev / mean : 0.0; + } + + size_t warmup_iterations; + bool cold; + bool record_as_whole; + + std::vector events; + std::function run_before_every_iteration_lambda = nullptr; + size_t total_gbench_iterations = 0; + bool reset_total_gbench_iterations_every_run = true; + std::vector times; + bool has_set_throughput = false; +}; + +struct autotune_interface +{ + virtual std::string name() const = 0; + virtual std::string sort_key() const + { + return name(); + }; + virtual ~autotune_interface() = default; + virtual void run(state&& state) = 0; +}; + +class executor +{ +public: + executor(int argc, + char* argv[], + size_t default_bytes, + size_t default_batch_iterations, + size_t default_warmup_iterations, + bool default_cold = true, + int default_trials = -1) + { + cli::Parser parser(argc, argv); + + set_optional_parser_flags(parser, + default_bytes, + default_batch_iterations, + default_warmup_iterations, + default_cold, + default_trials); + + parser.run_and_exit_if_error(); + + benchmark::Initialize(&argc, argv); + + parse(parser); + + add_context(); + } + + template + void queue_fn(const std::string& name, T bench_fn) + { + apply_settings(benchmark::RegisterBenchmark(name.c_str(), + [=](benchmark::State& gbench_state) + { bench_fn(new_state(gbench_state)); })); + } + + template + void queue_instance(Benchmark&& instance) + { + apply_settings(benchmark::RegisterBenchmark( + instance.name().c_str(), + [=](benchmark::State& gbench_state) + { + // run() requires a mutable instance, so create a mutable copy. + // Using [&instance] doesn't work, as it creates a dangling reference at runtime. + // Marking the lambda mutable doesn't work, as the &&instance it copies is const. + Benchmark(std::move(instance)).run(new_state(gbench_state)); + })); + } + + template + static bool queue_sorted_instance() + { + sorted_benchmarks().push_back(std::make_unique()); + return true; // Must return something, as this function gets called in global scope. + } + + template + static bool queue_autotune(BulkCreateFunction&& f) + { + std::forward(f)(sorted_benchmarks()); + return true; // Must return something, as this function gets called in global scope. + } + + void run() + { + register_sorted_subset(parallel_instance, parallel_instances); + benchmark::RunSpecifiedBenchmarks(); + } + +private: + void set_optional_parser_flags(cli::Parser& parser, + size_t default_bytes, + size_t default_batch_iterations, + size_t default_warmup_iterations, + bool default_cold, + int default_trials) + { + parser.set_optional("size", "size", default_bytes, "size in bytes"); + parser.set_optional("batch_iterations", + "batch_iterations", + default_batch_iterations, + "number of batch iterations"); + parser.set_optional("warmup_iterations", + "warmup_iterations", + default_warmup_iterations, + "number of warmup iterations"); + parser.set_optional("hot", + "hot", + !default_cold, + "don't clear the gpu cache on every batch iteration"); + parser.set_optional( + "record_as_whole", + "record_as_whole", + false, + "record the batch iterations as a whole, at the very start and end, which necessitates " + "that gpu cache clearing between iterations can't be done"); + + parser.set_optional("seed", "seed", "random", get_seed_message()); + parser.set_optional("trials", "trials", default_trials, "number of iterations"); + parser.set_optional("name_format", + "name_format", + "human", + "either: json,human,txt"); + + // Optionally run an evenly split subset of benchmarks for autotuning. + parser.set_optional("parallel_instance", + "parallel_instance", + 0, + "parallel instance index"); + parser.set_optional("parallel_instances", + "parallel_instances", + 1, + "total parallel instances"); + } + + void parse(cli::Parser& parser) + { + size = parser.get("size"); + + seed_type = parser.get("seed"); + + seed = managed_seed(seed_type); + + batch_iterations = parser.get("batch_iterations"); + warmup_iterations = parser.get("warmup_iterations"); + + cold = !parser.get("hot"); + record_as_whole = parser.get("record_as_whole"); + + trials = parser.get("trials"); + parallel_instance = parser.get("parallel_instance"); + parallel_instances = parser.get("parallel_instances"); + + bench_naming::set_format(parser.get("name_format")); + } + + void add_context() + { + benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("seed", seed_type); + + benchmark::AddCustomContext("batch_iterations", std::to_string(batch_iterations)); + benchmark::AddCustomContext("warmup_iterations", std::to_string(warmup_iterations)); + + hipDeviceProp_t devProp; + int device_id = 0; + HIP_CHECK(hipGetDevice(&device_id)); + HIP_CHECK(hipGetDeviceProperties(&devProp, device_id)); + + auto str = [](const std::string& name, const std::string& val) + { benchmark::AddCustomContext(name, val); }; + + auto num = [](const std::string& name, const auto& value) + { benchmark::AddCustomContext(name, std::to_string(value)); }; + + auto dim2 = [num](const std::string& name, const auto* values) + { + num(name + "_x", values[0]); + num(name + "_y", values[1]); + }; + + auto dim3 = [num, dim2](const std::string& name, const auto* values) + { + dim2(name, values); + num(name + "_z", values[2]); + }; + + str("hdp_name", devProp.name); + num("hdp_total_global_mem", devProp.totalGlobalMem); + num("hdp_shared_mem_per_block", devProp.sharedMemPerBlock); + num("hdp_regs_per_block", devProp.regsPerBlock); + num("hdp_warp_size", devProp.warpSize); + num("hdp_max_threads_per_block", devProp.maxThreadsPerBlock); + dim3("hdp_max_threads_dim", devProp.maxThreadsDim); + dim3("hdp_max_grid_size", devProp.maxGridSize); + num("hdp_clock_rate", devProp.clockRate); + num("hdp_memory_clock_rate", devProp.memoryClockRate); + num("hdp_memory_bus_width", devProp.memoryBusWidth); + num("hdp_total_const_mem", devProp.totalConstMem); + num("hdp_major", devProp.major); + num("hdp_minor", devProp.minor); + num("hdp_multi_processor_count", devProp.multiProcessorCount); + num("hdp_l2_cache_size", devProp.l2CacheSize); + num("hdp_max_threads_per_multiprocessor", devProp.maxThreadsPerMultiProcessor); + num("hdp_compute_mode", devProp.computeMode); + num("hdp_clock_instruction_rate", devProp.clockInstructionRate); + num("hdp_concurrent_kernels", devProp.concurrentKernels); + num("hdp_pci_domain_id", devProp.pciDomainID); + num("hdp_pci_bus_id", devProp.pciBusID); + num("hdp_pci_device_id", devProp.pciDeviceID); + num("hdp_max_shared_memory_per_multi_processor", devProp.maxSharedMemoryPerMultiProcessor); + num("hdp_is_multi_gpu_board", devProp.isMultiGpuBoard); + num("hdp_can_map_host_memory", devProp.canMapHostMemory); + str("hdp_gcn_arch_name", devProp.gcnArchName); + num("hdp_integrated", devProp.integrated); + num("hdp_cooperative_launch", devProp.cooperativeLaunch); + num("hdp_cooperative_multi_device_launch", devProp.cooperativeMultiDeviceLaunch); + num("hdp_max_texture_1d_linear", devProp.maxTexture1DLinear); + num("hdp_max_texture_1d", devProp.maxTexture1D); + dim2("hdp_max_texture_2d", devProp.maxTexture2D); + dim3("hdp_max_texture_3d", devProp.maxTexture3D); + num("hdp_mem_pitch", devProp.memPitch); + num("hdp_texture_alignment", devProp.textureAlignment); + num("hdp_texture_pitch_alignment", devProp.texturePitchAlignment); + num("hdp_kernel_exec_timeout_enabled", devProp.kernelExecTimeoutEnabled); + num("hdp_ecc_enabled", devProp.ECCEnabled); + num("hdp_tcc_driver", devProp.tccDriver); + num("hdp_cooperative_multi_device_unmatched_func", + devProp.cooperativeMultiDeviceUnmatchedFunc); + num("hdp_cooperative_multi_device_unmatched_grid_dim", + devProp.cooperativeMultiDeviceUnmatchedGridDim); + num("hdp_cooperative_multi_device_unmatched_block_dim", + devProp.cooperativeMultiDeviceUnmatchedBlockDim); + num("hdp_cooperative_multi_device_unmatched_shared_mem", + devProp.cooperativeMultiDeviceUnmatchedSharedMem); + num("hdp_is_large_bar", devProp.isLargeBar); + num("hdp_asic_revision", devProp.asicRevision); + num("hdp_managed_memory", devProp.managedMemory); + num("hdp_direct_managed_mem_access_from_host", devProp.directManagedMemAccessFromHost); + num("hdp_concurrent_managed_access", devProp.concurrentManagedAccess); + num("hdp_pageable_memory_access", devProp.pageableMemoryAccess); + num("hdp_pageable_memory_access_uses_host_page_tables", + devProp.pageableMemoryAccessUsesHostPageTables); + + const auto arch = devProp.arch; + num("hdp_arch_has_global_int32_atomics", arch.hasGlobalInt32Atomics); + num("hdp_arch_has_global_float_atomic_exch", arch.hasGlobalFloatAtomicExch); + num("hdp_arch_has_shared_int32_atomics", arch.hasSharedInt32Atomics); + num("hdp_arch_has_shared_float_atomic_exch", arch.hasSharedFloatAtomicExch); + num("hdp_arch_has_float_atomic_add", arch.hasFloatAtomicAdd); + num("hdp_arch_has_global_int64_atomics", arch.hasGlobalInt64Atomics); + num("hdp_arch_has_shared_int64_atomics", arch.hasSharedInt64Atomics); + num("hdp_arch_has_doubles", arch.hasDoubles); + num("hdp_arch_has_warp_vote", arch.hasWarpVote); + num("hdp_arch_has_warp_ballot", arch.hasWarpBallot); + num("hdp_arch_has_warp_shuffle", arch.hasWarpShuffle); + num("hdp_arch_has_funnel_shift", arch.hasFunnelShift); + num("hdp_arch_has_thread_fence_system", arch.hasThreadFenceSystem); + num("hdp_arch_has_sync_threads_ext", arch.hasSyncThreadsExt); + num("hdp_arch_has_surface_funcs", arch.hasSurfaceFuncs); + num("hdp_arch_has_3d_grid", arch.has3dGrid); + num("hdp_arch_has_dynamic_parallelism", arch.hasDynamicParallelism); + } + + static std::vector>& sorted_benchmarks() + { + static std::vector> sorted_benchmarks; + return sorted_benchmarks; + } + + state new_state(benchmark::State& gbench_state) + { + return state(stream, + size, + seed, + batch_iterations, + gbench_state, + warmup_iterations, + cold, + record_as_whole); + } + + void apply_settings(benchmark::internal::Benchmark* b) + { + b->UseManualTime(); + b->Unit(benchmark::kMillisecond); + + // trials is -1 by default. + if(trials > 0) + { + b->Iterations(trials); + } + } + + // Register a subset of all benchmarks for the current parallel instance. + void register_sorted_subset(int parallel_instance_index, int parallel_instance_count) + { + // Sort to get a consistent order, because the order of static variable initialization is undefined by the C++ standard. + std::sort(sorted_benchmarks().begin(), + sorted_benchmarks().end(), + [](const auto& l, const auto& r) { return l->sort_key() < r->sort_key(); }); + + size_t configs_per_instance + = (sorted_benchmarks().size() + parallel_instance_count - 1) / parallel_instance_count; + size_t start + = std::min(parallel_instance_index * configs_per_instance, sorted_benchmarks().size()); + size_t end = std::min((parallel_instance_index + 1) * configs_per_instance, + sorted_benchmarks().size()); + + for(size_t i = start; i < end; ++i) + { + autotune_interface* benchmark = sorted_benchmarks().at(i).get(); + + apply_settings(benchmark::RegisterBenchmark( + benchmark->name().c_str(), + [=](benchmark::State& gbench_state) { benchmark->run(new_state(gbench_state)); })); + } + } + + hipStream_t stream = hipStreamDefault; + size_t size; + std::string seed_type; + managed_seed seed; + size_t batch_iterations; + size_t warmup_iterations; + bool cold; + bool record_as_whole; + + int trials; + int parallel_instance; + int parallel_instances; +}; + +} // namespace benchmark_utils + #endif // ROCPRIM_BENCHMARK_UTILS_HPP_ diff --git a/benchmark/benchmark_warp_exchange.cpp b/benchmark/benchmark_warp_exchange.cpp index 5aafec1e1..152ff0020 100644 --- a/benchmark/benchmark_warp_exchange.cpp +++ b/benchmark/benchmark_warp_exchange.cpp @@ -21,13 +21,11 @@ // SOFTWARE. #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" #include "../common/utils.hpp" +#include "../common/utils_device_ptr.hpp" #include "../common/warp_exchange.hpp" -// Google Benchmark #include // HIP API @@ -44,10 +42,6 @@ #include #include -#ifndef DEFAULT_BYTES -const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; -#endif - struct ScatterToStripedOp { template @@ -171,286 +165,219 @@ template -void run_benchmark(benchmark::State& state, hipStream_t stream, size_t bytes) +void run_benchmark(benchmark_utils::state&& state) { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + // Calculate the number of elements size_t N = bytes / sizeof(T); - constexpr unsigned int trials = 200; - constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; - const unsigned int size = items_per_block * ((N + items_per_block - 1) / items_per_block); - - T* d_output; - HIP_CHECK(hipMalloc(&d_output, size * sizeof(T))); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - warp_exchange_kernel - <<>>(d_output, trials); - - HIP_CHECK(hipPeekAtLastError()); + constexpr uint64_t trials = 200; + constexpr uint64_t items_per_block = BlockSize * ItemsPerThread; + const uint64_t size = items_per_block * ((N + items_per_block - 1) / items_per_block); - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } + common::device_ptr d_output(size); - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * trials * size * sizeof(T)); - state.SetItemsProcessed(state.iterations() * trials * size); + state.run( + [&] + { + warp_exchange_kernel + <<>>(d_output.get(), + trials); + }); - HIP_CHECK(hipFree(d_output)); + state.set_throughput(trials * size, sizeof(T)); } -#define CREATE_BENCHMARK(T, BS, IT, WS, OP) \ - benchmark::RegisterBenchmark(bench_naming::format_name("{lvl:warp,algo:exchange,key_type:" #T \ - ",operation:" #OP ",ws:" #WS \ - ",cfg:{bs:" #BS ",ipt:" #IT "}}") \ - .c_str(), \ - &run_benchmark, \ - stream, \ - bytes) +#define CREATE_BENCHMARK(T, BS, IT, WS, OP) \ + executor.queue_fn(bench_naming::format_name("{lvl:warp,algo:exchange,key_type:" #T \ + ",operation:" #OP ",ws:" #WS ",cfg:{bs:" #BS \ + ",ipt:" #IT "}}") \ + .c_str(), \ + run_benchmark); int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - - // Add benchmarks - std::vector benchmarks{ - CREATE_BENCHMARK(int, 256, 1, 16, common::BlockedToStripedOp), - CREATE_BENCHMARK(int, 256, 1, 32, common::BlockedToStripedOp), - CREATE_BENCHMARK(int, 256, 4, 16, common::BlockedToStripedOp), - CREATE_BENCHMARK(int, 256, 4, 32, common::BlockedToStripedOp), - CREATE_BENCHMARK(int, 256, 16, 16, common::BlockedToStripedOp), - CREATE_BENCHMARK(int, 256, 16, 32, common::BlockedToStripedOp), - CREATE_BENCHMARK(int, 256, 32, 32, common::BlockedToStripedOp), - - CREATE_BENCHMARK(int, 256, 1, 16, common::StripedToBlockedOp), - CREATE_BENCHMARK(int, 256, 1, 32, common::StripedToBlockedOp), - CREATE_BENCHMARK(int, 256, 4, 16, common::StripedToBlockedOp), - CREATE_BENCHMARK(int, 256, 4, 32, common::StripedToBlockedOp), - CREATE_BENCHMARK(int, 256, 16, 16, common::StripedToBlockedOp), - CREATE_BENCHMARK(int, 256, 16, 32, common::StripedToBlockedOp), - CREATE_BENCHMARK(int, 256, 32, 32, common::StripedToBlockedOp), - - CREATE_BENCHMARK(int, 256, 1, 16, common::BlockedToStripedShuffleOp), - CREATE_BENCHMARK(int, 256, 1, 32, common::BlockedToStripedShuffleOp), - CREATE_BENCHMARK(int, 256, 4, 16, common::BlockedToStripedShuffleOp), - CREATE_BENCHMARK(int, 256, 4, 32, common::BlockedToStripedShuffleOp), - CREATE_BENCHMARK(int, 256, 16, 16, common::BlockedToStripedShuffleOp), - CREATE_BENCHMARK(int, 256, 16, 32, common::BlockedToStripedShuffleOp), - CREATE_BENCHMARK(int, 256, 32, 32, common::BlockedToStripedShuffleOp), - - CREATE_BENCHMARK(int, 256, 1, 16, common::StripedToBlockedShuffleOp), - CREATE_BENCHMARK(int, 256, 1, 32, common::StripedToBlockedShuffleOp), - CREATE_BENCHMARK(int, 256, 4, 16, common::StripedToBlockedShuffleOp), - CREATE_BENCHMARK(int, 256, 4, 32, common::StripedToBlockedShuffleOp), - CREATE_BENCHMARK(int, 256, 16, 16, common::StripedToBlockedShuffleOp), - CREATE_BENCHMARK(int, 256, 16, 32, common::StripedToBlockedShuffleOp), - CREATE_BENCHMARK(int, 256, 32, 32, common::StripedToBlockedShuffleOp), - - CREATE_BENCHMARK(int, 256, 1, 16, ScatterToStripedOp), - CREATE_BENCHMARK(int, 256, 1, 32, ScatterToStripedOp), - CREATE_BENCHMARK(int, 256, 4, 16, ScatterToStripedOp), - CREATE_BENCHMARK(int, 256, 4, 32, ScatterToStripedOp), - CREATE_BENCHMARK(int, 256, 16, 16, ScatterToStripedOp), - CREATE_BENCHMARK(int, 256, 16, 32, ScatterToStripedOp), - - CREATE_BENCHMARK(rocprim::int128_t, 256, 1, 16, common::BlockedToStripedOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 1, 32, common::BlockedToStripedOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 4, 16, common::BlockedToStripedOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 4, 32, common::BlockedToStripedOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 16, 16, common::BlockedToStripedOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 16, 32, common::BlockedToStripedOp), - - CREATE_BENCHMARK(rocprim::int128_t, 256, 1, 16, common::StripedToBlockedOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 1, 32, common::StripedToBlockedOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 4, 16, common::StripedToBlockedOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 4, 32, common::StripedToBlockedOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 16, 16, common::StripedToBlockedOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 16, 32, common::StripedToBlockedOp), - - CREATE_BENCHMARK(rocprim::int128_t, 256, 1, 16, common::BlockedToStripedShuffleOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 1, 32, common::BlockedToStripedShuffleOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 4, 16, common::BlockedToStripedShuffleOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 4, 32, common::BlockedToStripedShuffleOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 16, 16, common::BlockedToStripedShuffleOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 16, 32, common::BlockedToStripedShuffleOp), - - CREATE_BENCHMARK(rocprim::int128_t, 256, 1, 16, common::StripedToBlockedShuffleOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 1, 32, common::StripedToBlockedShuffleOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 4, 16, common::StripedToBlockedShuffleOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 4, 32, common::StripedToBlockedShuffleOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 16, 16, common::StripedToBlockedShuffleOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 16, 32, common::StripedToBlockedShuffleOp), - - CREATE_BENCHMARK(rocprim::int128_t, 256, 1, 16, ScatterToStripedOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 1, 32, ScatterToStripedOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 4, 16, ScatterToStripedOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 4, 32, ScatterToStripedOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 16, 16, ScatterToStripedOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 16, 32, ScatterToStripedOp), - - CREATE_BENCHMARK(rocprim::uint128_t, 256, 1, 16, common::BlockedToStripedOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 1, 32, common::BlockedToStripedOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 4, 16, common::BlockedToStripedOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 4, 32, common::BlockedToStripedOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 16, 16, common::BlockedToStripedOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 16, 32, common::BlockedToStripedOp), - - CREATE_BENCHMARK(rocprim::uint128_t, 256, 1, 16, common::StripedToBlockedOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 1, 32, common::StripedToBlockedOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 4, 16, common::StripedToBlockedOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 4, 32, common::StripedToBlockedOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 16, 16, common::StripedToBlockedOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 16, 32, common::StripedToBlockedOp), - - CREATE_BENCHMARK(rocprim::uint128_t, 256, 1, 16, common::BlockedToStripedShuffleOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 1, 32, common::BlockedToStripedShuffleOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 4, 16, common::BlockedToStripedShuffleOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 4, 32, common::BlockedToStripedShuffleOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 16, 16, common::BlockedToStripedShuffleOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 16, 32, common::BlockedToStripedShuffleOp), - - CREATE_BENCHMARK(rocprim::uint128_t, 256, 1, 16, common::StripedToBlockedShuffleOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 1, 32, common::StripedToBlockedShuffleOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 4, 16, common::StripedToBlockedShuffleOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 4, 32, common::StripedToBlockedShuffleOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 16, 16, common::StripedToBlockedShuffleOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 16, 32, common::StripedToBlockedShuffleOp), - - CREATE_BENCHMARK(rocprim::uint128_t, 256, 1, 16, ScatterToStripedOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 1, 32, ScatterToStripedOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 4, 16, ScatterToStripedOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 4, 32, ScatterToStripedOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 16, 16, ScatterToStripedOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 16, 32, ScatterToStripedOp)}; + benchmark_utils::executor executor(argc, argv, 128 * benchmark_utils::MiB, 1, 0); + + CREATE_BENCHMARK(int, 256, 1, 16, common::BlockedToStripedOp) + CREATE_BENCHMARK(int, 256, 1, 32, common::BlockedToStripedOp) + CREATE_BENCHMARK(int, 256, 4, 16, common::BlockedToStripedOp) + CREATE_BENCHMARK(int, 256, 4, 32, common::BlockedToStripedOp) + CREATE_BENCHMARK(int, 256, 16, 16, common::BlockedToStripedOp) + CREATE_BENCHMARK(int, 256, 16, 32, common::BlockedToStripedOp) + CREATE_BENCHMARK(int, 256, 32, 32, common::BlockedToStripedOp) + + CREATE_BENCHMARK(int, 256, 1, 16, common::StripedToBlockedOp) + CREATE_BENCHMARK(int, 256, 1, 32, common::StripedToBlockedOp) + CREATE_BENCHMARK(int, 256, 4, 16, common::StripedToBlockedOp) + CREATE_BENCHMARK(int, 256, 4, 32, common::StripedToBlockedOp) + CREATE_BENCHMARK(int, 256, 16, 16, common::StripedToBlockedOp) + CREATE_BENCHMARK(int, 256, 16, 32, common::StripedToBlockedOp) + CREATE_BENCHMARK(int, 256, 32, 32, common::StripedToBlockedOp) + + CREATE_BENCHMARK(int, 256, 1, 16, common::BlockedToStripedShuffleOp) + CREATE_BENCHMARK(int, 256, 1, 32, common::BlockedToStripedShuffleOp) + CREATE_BENCHMARK(int, 256, 4, 16, common::BlockedToStripedShuffleOp) + CREATE_BENCHMARK(int, 256, 4, 32, common::BlockedToStripedShuffleOp) + CREATE_BENCHMARK(int, 256, 16, 16, common::BlockedToStripedShuffleOp) + CREATE_BENCHMARK(int, 256, 16, 32, common::BlockedToStripedShuffleOp) + CREATE_BENCHMARK(int, 256, 32, 32, common::BlockedToStripedShuffleOp) + + CREATE_BENCHMARK(int, 256, 1, 16, common::StripedToBlockedShuffleOp) + CREATE_BENCHMARK(int, 256, 1, 32, common::StripedToBlockedShuffleOp) + CREATE_BENCHMARK(int, 256, 4, 16, common::StripedToBlockedShuffleOp) + CREATE_BENCHMARK(int, 256, 4, 32, common::StripedToBlockedShuffleOp) + CREATE_BENCHMARK(int, 256, 16, 16, common::StripedToBlockedShuffleOp) + CREATE_BENCHMARK(int, 256, 16, 32, common::StripedToBlockedShuffleOp) + CREATE_BENCHMARK(int, 256, 32, 32, common::StripedToBlockedShuffleOp) + + CREATE_BENCHMARK(int, 256, 1, 16, ScatterToStripedOp) + CREATE_BENCHMARK(int, 256, 1, 32, ScatterToStripedOp) + CREATE_BENCHMARK(int, 256, 4, 16, ScatterToStripedOp) + CREATE_BENCHMARK(int, 256, 4, 32, ScatterToStripedOp) + CREATE_BENCHMARK(int, 256, 16, 16, ScatterToStripedOp) + CREATE_BENCHMARK(int, 256, 16, 32, ScatterToStripedOp) + + CREATE_BENCHMARK(rocprim::int128_t, 256, 1, 16, common::BlockedToStripedOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 1, 32, common::BlockedToStripedOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 4, 16, common::BlockedToStripedOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 4, 32, common::BlockedToStripedOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 16, 16, common::BlockedToStripedOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 16, 32, common::BlockedToStripedOp) + + CREATE_BENCHMARK(rocprim::int128_t, 256, 1, 16, common::StripedToBlockedOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 1, 32, common::StripedToBlockedOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 4, 16, common::StripedToBlockedOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 4, 32, common::StripedToBlockedOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 16, 16, common::StripedToBlockedOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 16, 32, common::StripedToBlockedOp) + + CREATE_BENCHMARK(rocprim::int128_t, 256, 1, 16, common::BlockedToStripedShuffleOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 1, 32, common::BlockedToStripedShuffleOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 4, 16, common::BlockedToStripedShuffleOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 4, 32, common::BlockedToStripedShuffleOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 16, 16, common::BlockedToStripedShuffleOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 16, 32, common::BlockedToStripedShuffleOp) + + CREATE_BENCHMARK(rocprim::int128_t, 256, 1, 16, common::StripedToBlockedShuffleOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 1, 32, common::StripedToBlockedShuffleOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 4, 16, common::StripedToBlockedShuffleOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 4, 32, common::StripedToBlockedShuffleOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 16, 16, common::StripedToBlockedShuffleOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 16, 32, common::StripedToBlockedShuffleOp) + + CREATE_BENCHMARK(rocprim::int128_t, 256, 1, 16, ScatterToStripedOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 1, 32, ScatterToStripedOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 4, 16, ScatterToStripedOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 4, 32, ScatterToStripedOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 16, 16, ScatterToStripedOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 16, 32, ScatterToStripedOp) + + CREATE_BENCHMARK(rocprim::uint128_t, 256, 1, 16, common::BlockedToStripedOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 1, 32, common::BlockedToStripedOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 4, 16, common::BlockedToStripedOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 4, 32, common::BlockedToStripedOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 16, 16, common::BlockedToStripedOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 16, 32, common::BlockedToStripedOp) + + CREATE_BENCHMARK(rocprim::uint128_t, 256, 1, 16, common::StripedToBlockedOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 1, 32, common::StripedToBlockedOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 4, 16, common::StripedToBlockedOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 4, 32, common::StripedToBlockedOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 16, 16, common::StripedToBlockedOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 16, 32, common::StripedToBlockedOp) + + CREATE_BENCHMARK(rocprim::uint128_t, 256, 1, 16, common::BlockedToStripedShuffleOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 1, 32, common::BlockedToStripedShuffleOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 4, 16, common::BlockedToStripedShuffleOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 4, 32, common::BlockedToStripedShuffleOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 16, 16, common::BlockedToStripedShuffleOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 16, 32, common::BlockedToStripedShuffleOp) + + CREATE_BENCHMARK(rocprim::uint128_t, 256, 1, 16, common::StripedToBlockedShuffleOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 1, 32, common::StripedToBlockedShuffleOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 4, 16, common::StripedToBlockedShuffleOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 4, 32, common::StripedToBlockedShuffleOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 16, 16, common::StripedToBlockedShuffleOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 16, 32, common::StripedToBlockedShuffleOp) + + CREATE_BENCHMARK(rocprim::uint128_t, 256, 1, 16, ScatterToStripedOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 1, 32, ScatterToStripedOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 4, 16, ScatterToStripedOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 4, 32, ScatterToStripedOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 16, 16, ScatterToStripedOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 16, 32, ScatterToStripedOp) int hip_device = 0; - HIP_CHECK(::rocprim::detail::get_device_from_stream(stream, hip_device)); + HIP_CHECK(::rocprim::detail::get_device_from_stream(hipStreamDefault, hip_device)); if(is_warp_size_supported(64, hip_device)) { - std::vector additional_benchmarks{ - CREATE_BENCHMARK(int, 256, 1, 64, common::BlockedToStripedOp), - CREATE_BENCHMARK(int, 256, 4, 64, common::BlockedToStripedOp), - CREATE_BENCHMARK(int, 256, 16, 64, common::BlockedToStripedOp), - CREATE_BENCHMARK(int, 256, 64, 64, common::BlockedToStripedOp), - - CREATE_BENCHMARK(int, 256, 1, 64, common::StripedToBlockedOp), - CREATE_BENCHMARK(int, 256, 4, 64, common::StripedToBlockedOp), - CREATE_BENCHMARK(int, 256, 16, 64, common::StripedToBlockedOp), - CREATE_BENCHMARK(int, 256, 64, 64, common::StripedToBlockedOp), - - CREATE_BENCHMARK(int, 256, 1, 64, common::BlockedToStripedShuffleOp), - CREATE_BENCHMARK(int, 256, 4, 64, common::BlockedToStripedShuffleOp), - CREATE_BENCHMARK(int, 256, 16, 64, common::BlockedToStripedShuffleOp), - CREATE_BENCHMARK(int, 256, 64, 64, common::BlockedToStripedShuffleOp), - - CREATE_BENCHMARK(int, 256, 1, 64, common::StripedToBlockedShuffleOp), - CREATE_BENCHMARK(int, 256, 4, 64, common::StripedToBlockedShuffleOp), - CREATE_BENCHMARK(int, 256, 16, 64, common::StripedToBlockedShuffleOp), - CREATE_BENCHMARK(int, 256, 64, 64, common::StripedToBlockedShuffleOp), - - CREATE_BENCHMARK(int, 256, 1, 64, ScatterToStripedOp), - CREATE_BENCHMARK(int, 256, 4, 64, ScatterToStripedOp), - CREATE_BENCHMARK(int, 256, 16, 64, ScatterToStripedOp), - - CREATE_BENCHMARK(rocprim::int128_t, 256, 1, 64, common::BlockedToStripedOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 4, 64, common::BlockedToStripedOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 16, 64, common::BlockedToStripedOp), - - CREATE_BENCHMARK(rocprim::int128_t, 256, 1, 64, common::StripedToBlockedOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 4, 64, common::StripedToBlockedOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 16, 64, common::StripedToBlockedOp), - - CREATE_BENCHMARK(rocprim::int128_t, 256, 1, 64, common::BlockedToStripedShuffleOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 4, 64, common::BlockedToStripedShuffleOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 16, 64, common::BlockedToStripedShuffleOp), - - CREATE_BENCHMARK(rocprim::int128_t, 256, 1, 64, common::StripedToBlockedShuffleOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 4, 64, common::StripedToBlockedShuffleOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 16, 64, common::StripedToBlockedShuffleOp), - - CREATE_BENCHMARK(rocprim::int128_t, 256, 1, 64, ScatterToStripedOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 4, 64, ScatterToStripedOp), - CREATE_BENCHMARK(rocprim::int128_t, 256, 16, 64, ScatterToStripedOp), - - CREATE_BENCHMARK(rocprim::uint128_t, 256, 1, 64, common::BlockedToStripedOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 4, 64, common::BlockedToStripedOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 16, 64, common::BlockedToStripedOp), - - CREATE_BENCHMARK(rocprim::uint128_t, 256, 1, 64, common::StripedToBlockedOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 4, 64, common::StripedToBlockedOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 16, 64, common::StripedToBlockedOp), - - CREATE_BENCHMARK(rocprim::uint128_t, 256, 1, 64, common::BlockedToStripedShuffleOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 4, 64, common::BlockedToStripedShuffleOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 16, 64, common::BlockedToStripedShuffleOp), - - CREATE_BENCHMARK(rocprim::uint128_t, 256, 1, 64, common::StripedToBlockedShuffleOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 4, 64, common::StripedToBlockedShuffleOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 16, 64, common::StripedToBlockedShuffleOp), - - CREATE_BENCHMARK(rocprim::uint128_t, 256, 1, 64, ScatterToStripedOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 4, 64, ScatterToStripedOp), - CREATE_BENCHMARK(rocprim::uint128_t, 256, 16, 64, ScatterToStripedOp)}; - benchmarks.insert(benchmarks.end(), - additional_benchmarks.begin(), - additional_benchmarks.end()); - } - - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } + CREATE_BENCHMARK(int, 256, 1, 64, common::BlockedToStripedOp) + CREATE_BENCHMARK(int, 256, 4, 64, common::BlockedToStripedOp) + CREATE_BENCHMARK(int, 256, 16, 64, common::BlockedToStripedOp) + CREATE_BENCHMARK(int, 256, 64, 64, common::BlockedToStripedOp) + + CREATE_BENCHMARK(int, 256, 1, 64, common::StripedToBlockedOp) + CREATE_BENCHMARK(int, 256, 4, 64, common::StripedToBlockedOp) + CREATE_BENCHMARK(int, 256, 16, 64, common::StripedToBlockedOp) + CREATE_BENCHMARK(int, 256, 64, 64, common::StripedToBlockedOp) + + CREATE_BENCHMARK(int, 256, 1, 64, common::BlockedToStripedShuffleOp) + CREATE_BENCHMARK(int, 256, 4, 64, common::BlockedToStripedShuffleOp) + CREATE_BENCHMARK(int, 256, 16, 64, common::BlockedToStripedShuffleOp) + CREATE_BENCHMARK(int, 256, 64, 64, common::BlockedToStripedShuffleOp) + + CREATE_BENCHMARK(int, 256, 1, 64, common::StripedToBlockedShuffleOp) + CREATE_BENCHMARK(int, 256, 4, 64, common::StripedToBlockedShuffleOp) + CREATE_BENCHMARK(int, 256, 16, 64, common::StripedToBlockedShuffleOp) + CREATE_BENCHMARK(int, 256, 64, 64, common::StripedToBlockedShuffleOp) + + CREATE_BENCHMARK(int, 256, 1, 64, ScatterToStripedOp) + CREATE_BENCHMARK(int, 256, 4, 64, ScatterToStripedOp) + CREATE_BENCHMARK(int, 256, 16, 64, ScatterToStripedOp) + + CREATE_BENCHMARK(rocprim::int128_t, 256, 1, 64, common::BlockedToStripedOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 4, 64, common::BlockedToStripedOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 16, 64, common::BlockedToStripedOp) + + CREATE_BENCHMARK(rocprim::int128_t, 256, 1, 64, common::StripedToBlockedOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 4, 64, common::StripedToBlockedOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 16, 64, common::StripedToBlockedOp) + + CREATE_BENCHMARK(rocprim::int128_t, 256, 1, 64, common::BlockedToStripedShuffleOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 4, 64, common::BlockedToStripedShuffleOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 16, 64, common::BlockedToStripedShuffleOp) + + CREATE_BENCHMARK(rocprim::int128_t, 256, 1, 64, common::StripedToBlockedShuffleOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 4, 64, common::StripedToBlockedShuffleOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 16, 64, common::StripedToBlockedShuffleOp) + + CREATE_BENCHMARK(rocprim::int128_t, 256, 1, 64, ScatterToStripedOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 4, 64, ScatterToStripedOp) + CREATE_BENCHMARK(rocprim::int128_t, 256, 16, 64, ScatterToStripedOp) + + CREATE_BENCHMARK(rocprim::uint128_t, 256, 1, 64, common::BlockedToStripedOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 4, 64, common::BlockedToStripedOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 16, 64, common::BlockedToStripedOp) + + CREATE_BENCHMARK(rocprim::uint128_t, 256, 1, 64, common::StripedToBlockedOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 4, 64, common::StripedToBlockedOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 16, 64, common::StripedToBlockedOp) + + CREATE_BENCHMARK(rocprim::uint128_t, 256, 1, 64, common::BlockedToStripedShuffleOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 4, 64, common::BlockedToStripedShuffleOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 16, 64, common::BlockedToStripedShuffleOp) + + CREATE_BENCHMARK(rocprim::uint128_t, 256, 1, 64, common::StripedToBlockedShuffleOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 4, 64, common::StripedToBlockedShuffleOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 16, 64, common::StripedToBlockedShuffleOp) + + CREATE_BENCHMARK(rocprim::uint128_t, 256, 1, 64, ScatterToStripedOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 4, 64, ScatterToStripedOp) + CREATE_BENCHMARK(rocprim::uint128_t, 256, 16, 64, ScatterToStripedOp) } - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_warp_reduce.cpp b/benchmark/benchmark_warp_reduce.cpp index 87a909225..d580882c2 100644 --- a/benchmark/benchmark_warp_reduce.cpp +++ b/benchmark/benchmark_warp_reduce.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -21,10 +21,9 @@ // SOFTWARE. #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" -// Google Benchmark +#include "../common/utils_device_ptr.hpp" + #include // HIP API @@ -41,52 +40,54 @@ #include #include -#ifndef DEFAULT_BYTES -const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; -#endif - -template +template __global__ __launch_bounds__(ROCPRIM_DEFAULT_MAX_BLOCK_SIZE) void warp_reduce_kernel(const T* d_input, T* d_output) { - const unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; + if constexpr(VirtualWaveSize <= rocprim::arch::wavefront::max_size()) + { + const unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - auto value = d_input[i]; + auto value = d_input[i]; - using wreduce_t = rocprim::warp_reduce; - __shared__ typename wreduce_t::storage_type storage; - ROCPRIM_NO_UNROLL - for(unsigned int trial = 0; trial < Trials; ++trial) - { - wreduce_t().reduce(value, value, storage); - } + using wreduce_t = rocprim::warp_reduce; + __shared__ typename wreduce_t::storage_type storage; + ROCPRIM_NO_UNROLL + for(unsigned int trial = 0; trial < Trials; ++trial) + { + wreduce_t().reduce(value, value, storage); + } - d_output[i] = value; + d_output[i] = value; + } } -template +template __global__ __launch_bounds__(ROCPRIM_DEFAULT_MAX_BLOCK_SIZE) void segmented_warp_reduce_kernel(const T* d_input, Flag* d_flags, T* d_output) { - const unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; + if constexpr(VirtualWaveSize <= rocprim::arch::wavefront::max_size()) + { + const unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - auto value = d_input[i]; - auto flag = d_flags[i]; + auto value = d_input[i]; + auto flag = d_flags[i]; - using wreduce_t = rocprim::warp_reduce; - __shared__ typename wreduce_t::storage_type storage; - ROCPRIM_NO_UNROLL - for(unsigned int trial = 0; trial < Trials; ++trial) - { - wreduce_t().head_segmented_reduce(value, value, flag, storage); - } + using wreduce_t = rocprim::warp_reduce; + __shared__ typename wreduce_t::storage_type storage; + ROCPRIM_NO_UNROLL + for(unsigned int trial = 0; trial < Trials; ++trial) + { + wreduce_t().head_segmented_reduce(value, value, flag, storage); + } - d_output[i] = value; + d_output[i] = value; + } } template typename std::enable_if::type { - hipLaunchKernelGGL(HIP_KERNEL_NAME(warp_reduce_kernel), + hipLaunchKernelGGL(HIP_KERNEL_NAME(warp_reduce_kernel), dim3(size / BlockSize), dim3(BlockSize), 0, @@ -107,7 +108,7 @@ inline auto execute_warp_reduce_kernel( template typename std::enable_if::type { - hipLaunchKernelGGL(HIP_KERNEL_NAME(segmented_warp_reduce_kernel), - dim3(size / BlockSize), - dim3(BlockSize), - 0, - stream, - input, - flags, - output); + hipLaunchKernelGGL( + HIP_KERNEL_NAME(segmented_warp_reduce_kernel), + dim3(size / BlockSize), + dim3(BlockSize), + 0, + stream, + input, + flags, + output); HIP_CHECK(hipGetLastError()); } template -void run_benchmark(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) +void run_benchmark(benchmark_utils::state&& state) { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + using flag_type = unsigned char; // Calculate the number of elements @@ -148,139 +151,63 @@ void run_benchmark(benchmark::State& state, const auto random_range = limit_random_range(0, 10); std::vector input = get_random_data(size, random_range.first, random_range.second, seed.get_0()); - std::vector flags = get_random_data(size, 0, 1, seed.get_1()); - T* d_input; - flag_type* d_flags; - T* d_output; - HIP_CHECK(hipMalloc(reinterpret_cast(&d_input), size * sizeof(T))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_flags), size * sizeof(flag_type))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_output), size * sizeof(T))); - HIP_CHECK(hipMemcpy(d_input, input.data(), size * sizeof(T), hipMemcpyHostToDevice)); - HIP_CHECK(hipMemcpy(d_flags, flags.data(), size * sizeof(flag_type), hipMemcpyHostToDevice)); + std::vector flags = get_random_data(size, 0, 1, seed.get_1()); + common::device_ptr d_input(input); + common::device_ptr d_flags(flags); + common::device_ptr d_output(size); HIP_CHECK(hipDeviceSynchronize()); - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - execute_warp_reduce_kernel(d_input, - d_output, - d_flags, - size, - stream); - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * Trials * size * sizeof(T)); - state.SetItemsProcessed(state.iterations() * Trials * size); - - HIP_CHECK(hipFree(d_input)); - HIP_CHECK(hipFree(d_output)); - HIP_CHECK(hipFree(d_flags)); + state.run( + [&] + { + execute_warp_reduce_kernel( + d_input.get(), + d_output.get(), + d_flags.get(), + size, + stream); + }); + + state.set_throughput(Trials * size, sizeof(T)); } #define CREATE_BENCHMARK(T, WS, BS) \ - benchmark::RegisterBenchmark( \ + executor.queue_fn( \ bench_naming::format_name("{lvl:warp,algo:reduce,key_type:" #T ",broadcast_result:" \ + std::string(AllReduce ? "true" : "false") \ + ",segmented:" + std::string(Segmented ? "true" : "false") \ + ",ws:" #WS ",cfg:{bs:" #BS "}}") \ .c_str(), \ - run_benchmark, \ - bytes, \ - seed, \ - stream) + run_benchmark); -#define BENCHMARK_TYPE(type) \ - CREATE_BENCHMARK(type, 32, 64), CREATE_BENCHMARK(type, 37, 64), \ - CREATE_BENCHMARK(type, 61, 64), CREATE_BENCHMARK(type, 64, 64) +// clang-format off +#define BENCHMARK_TYPE(type) \ + CREATE_BENCHMARK(type, 32, 64) \ + CREATE_BENCHMARK(type, 37, 64) \ + CREATE_BENCHMARK(type, 61, 64) \ + CREATE_BENCHMARK(type, 64, 64) +// clang-format on template -void add_benchmarks(std::vector& benchmarks, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) +void add_benchmarks(benchmark_utils::executor& executor) { - std::vector bs = {BENCHMARK_TYPE(int), - BENCHMARK_TYPE(float), - BENCHMARK_TYPE(double), - BENCHMARK_TYPE(int8_t), - BENCHMARK_TYPE(uint8_t), - BENCHMARK_TYPE(rocprim::half), - BENCHMARK_TYPE(rocprim::int128_t), - BENCHMARK_TYPE(rocprim::uint128_t)}; - - benchmarks.insert(benchmarks.end(), bs.begin(), bs.end()); + BENCHMARK_TYPE(int) + BENCHMARK_TYPE(float) + BENCHMARK_TYPE(double) + BENCHMARK_TYPE(int8_t) + BENCHMARK_TYPE(uint8_t) + BENCHMARK_TYPE(rocprim::half) + BENCHMARK_TYPE(rocprim::int128_t) + BENCHMARK_TYPE(rocprim::uint128_t) } int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); + benchmark_utils::executor executor(argc, argv, 128 * benchmark_utils::MiB, 1, 0); - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); - - // Add benchmarks - std::vector benchmarks; - add_benchmarks(benchmarks, bytes, seed, stream); - add_benchmarks(benchmarks, bytes, seed, stream); - add_benchmarks(benchmarks, bytes, seed, stream); - - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } + add_benchmarks(executor); + add_benchmarks(executor); + add_benchmarks(executor); - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_warp_scan.cpp b/benchmark/benchmark_warp_scan.cpp index c17a0c4c5..7a814c320 100644 --- a/benchmark/benchmark_warp_scan.cpp +++ b/benchmark/benchmark_warp_scan.cpp @@ -21,13 +21,12 @@ // SOFTWARE. #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" #include "../common/utils_custom_type.hpp" +#include "../common/utils_device_ptr.hpp" -// Google Benchmark #include + // HIP API #include // rocPRIM @@ -40,10 +39,6 @@ #include #include -#ifndef DEFAULT_BYTES -const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; -#endif - enum class scan_type { inclusive_scan, @@ -51,16 +46,19 @@ enum class scan_type broadcast }; -template +template __global__ __launch_bounds__(ROCPRIM_DEFAULT_MAX_BLOCK_SIZE) void kernel(const T* input, T* output, const T init) { - Runner::template run(input, output, init); + if constexpr(VirtualWaveSize <= rocprim::arch::wavefront::max_size()) + { + Runner::template run(input, output, init); + } } struct inclusive_scan { - template + template __device__ static void run(const T* input, T* output, const T init) { @@ -68,7 +66,7 @@ struct inclusive_scan const unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; auto value = input[i]; - using wscan_t = rocprim::warp_scan; + using wscan_t = rocprim::warp_scan; __shared__ typename wscan_t::storage_type storage; ROCPRIM_NO_UNROLL for(unsigned int trial = 0; trial < Trials; ++trial) @@ -82,14 +80,14 @@ struct inclusive_scan struct exclusive_scan { - template + template __device__ static void run(const T* input, T* output, const T init) { const unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; auto value = input[i]; - using wscan_t = rocprim::warp_scan; + using wscan_t = rocprim::warp_scan; __shared__ typename wscan_t::storage_type storage; ROCPRIM_NO_UNROLL for(unsigned int trial = 0; trial < Trials; ++trial) @@ -103,17 +101,17 @@ struct exclusive_scan struct broadcast { - template + template __device__ static void run(const T* input, T* output, const T init) { (void)init; const unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - const unsigned int warp_id = i / WarpSize; - const unsigned int src_lane = warp_id % WarpSize; + const unsigned int warp_id = i / VirtualWaveSize; + const unsigned int src_lane = warp_id % VirtualWaveSize; auto value = input[i]; - using wscan_t = rocprim::warp_scan; + using wscan_t = rocprim::warp_scan; __shared__ typename wscan_t::storage_type storage; ROCPRIM_NO_UNROLL for(unsigned int trial = 0; trial < Trials; ++trial) @@ -127,182 +125,114 @@ struct broadcast template -void run_benchmark(benchmark::State& state, hipStream_t stream, size_t bytes) +void run_benchmark(benchmark_utils::state&& state) { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + // Calculate the number of elements size_t size = bytes / sizeof(T); // Make sure size is a multiple of BlockSize size = BlockSize * ((size + BlockSize - 1) / BlockSize); // Allocate and fill memory - std::vector input(size, (T)1); - T* d_input; - T* d_output; - HIP_CHECK(hipMalloc(reinterpret_cast(&d_input), size * sizeof(T))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_output), size * sizeof(T))); - HIP_CHECK(hipMemcpy(d_input, input.data(), size * sizeof(T), hipMemcpyHostToDevice)); + std::vector input(size, (T)1); + common::device_ptr d_input(input); + common::device_ptr d_output(size); HIP_CHECK(hipDeviceSynchronize()); - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - hipLaunchKernelGGL(HIP_KERNEL_NAME(kernel), - dim3(size / BlockSize), - dim3(BlockSize), - 0, - stream, - d_input, - d_output, - input[0]); - HIP_CHECK(hipGetLastError()); - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * size * sizeof(T) * Trials); - state.SetItemsProcessed(state.iterations() * size * Trials); - - HIP_CHECK(hipFree(d_input)); - HIP_CHECK(hipFree(d_output)); + state.run( + [&] + { + hipLaunchKernelGGL(HIP_KERNEL_NAME(kernel), + dim3(size / BlockSize), + dim3(BlockSize), + 0, + stream, + d_input.get(), + d_output.get(), + input[0]); + }); + + state.set_throughput(Trials * size, sizeof(T)); } -#define CREATE_BENCHMARK_IMPL(T, BS, WS, OP) \ - benchmark::RegisterBenchmark( \ - bench_naming::format_name("{lvl:warp,algo:scan,key_type:" #T ",subalgo:" + method_name \ - + ",ws:" #WS ",cfg:{bs:" #BS "}}") \ - .c_str(), \ - run_benchmark, \ - stream, \ - bytes) - -#define CREATE_BENCHMARK(T, BS, WS) CREATE_BENCHMARK_IMPL(T, BS, WS, Benchmark) - -#define BENCHMARK_TYPE(type) \ - CREATE_BENCHMARK(type, 64, 64), CREATE_BENCHMARK(type, 128, 64), \ - CREATE_BENCHMARK(type, 256, 64), CREATE_BENCHMARK(type, 256, 32), \ - CREATE_BENCHMARK(type, 256, 16), CREATE_BENCHMARK(type, 63, 63), \ - CREATE_BENCHMARK(type, 62, 31), CREATE_BENCHMARK(type, 60, 15) - -#define BENCHMARK_TYPE_P2(type) \ - CREATE_BENCHMARK(type, 64, 64), CREATE_BENCHMARK(type, 128, 64), \ - CREATE_BENCHMARK(type, 256, 64), CREATE_BENCHMARK(type, 256, 32), \ - CREATE_BENCHMARK(type, 256, 16) +#define CREATE_BENCHMARK(T, BS, WS) \ + executor.queue_fn(bench_naming::format_name("{lvl:warp,algo:scan,key_type:" #T ",subalgo:" \ + + method_name + ",ws:" #WS ",cfg:{bs:" #BS "}}") \ + .c_str(), \ + run_benchmark); + +// clang-format off +#define BENCHMARK_TYPE(type) \ + CREATE_BENCHMARK(type, 64, 64) \ + CREATE_BENCHMARK(type, 128, 64) \ + CREATE_BENCHMARK(type, 256, 64) \ + CREATE_BENCHMARK(type, 256, 32) \ + CREATE_BENCHMARK(type, 256, 16) \ + CREATE_BENCHMARK(type, 63, 63) \ + CREATE_BENCHMARK(type, 62, 31) \ + CREATE_BENCHMARK(type, 60, 15) +// clang-format on + +// clang-format off +#define BENCHMARK_TYPE_P2(type) \ + CREATE_BENCHMARK(type, 64, 64) \ + CREATE_BENCHMARK(type, 128, 64) \ + CREATE_BENCHMARK(type, 256, 64) \ + CREATE_BENCHMARK(type, 256, 32) \ + CREATE_BENCHMARK(type, 256, 16) +// clang-format on template -auto add_benchmarks(std::vector& benchmarks, - const std::string& method_name, - hipStream_t stream, - size_t bytes) +auto add_benchmarks(benchmark_utils::executor& executor, const std::string& method_name) -> std::enable_if_t::value || std::is_same::value> { using custom_double2 = common::custom_type; using custom_int_double = common::custom_type; - std::vector new_benchmarks - = {BENCHMARK_TYPE(int), - BENCHMARK_TYPE(float), - BENCHMARK_TYPE(double), - BENCHMARK_TYPE(int8_t), - BENCHMARK_TYPE(uint8_t), - BENCHMARK_TYPE(rocprim::half), - BENCHMARK_TYPE(custom_double2), - BENCHMARK_TYPE(custom_int_double), - BENCHMARK_TYPE(rocprim::int128_t), - BENCHMARK_TYPE(rocprim::uint128_t)}; - benchmarks.insert(benchmarks.end(), new_benchmarks.begin(), new_benchmarks.end()); + BENCHMARK_TYPE(int) + BENCHMARK_TYPE(float) + BENCHMARK_TYPE(double) + BENCHMARK_TYPE(int8_t) + BENCHMARK_TYPE(uint8_t) + BENCHMARK_TYPE(rocprim::half) + BENCHMARK_TYPE(custom_double2) + BENCHMARK_TYPE(custom_int_double) + BENCHMARK_TYPE(rocprim::int128_t) + BENCHMARK_TYPE(rocprim::uint128_t) } template -auto add_benchmarks(std::vector& benchmarks, - const std::string& method_name, - hipStream_t stream, - size_t bytes) -> std::enable_if_t::value> +auto add_benchmarks(benchmark_utils::executor& executor, const std::string& method_name) + -> std::enable_if_t::value> { using custom_double2 = common::custom_type; using custom_int_double = common::custom_type; - std::vector new_benchmarks - = {BENCHMARK_TYPE_P2(int), - BENCHMARK_TYPE_P2(float), - BENCHMARK_TYPE_P2(double), - BENCHMARK_TYPE_P2(int8_t), - BENCHMARK_TYPE_P2(uint8_t), - BENCHMARK_TYPE_P2(rocprim::half), - BENCHMARK_TYPE_P2(custom_double2), - BENCHMARK_TYPE_P2(custom_int_double), - BENCHMARK_TYPE_P2(rocprim::int128_t), - BENCHMARK_TYPE_P2(rocprim::uint128_t)}; - benchmarks.insert(benchmarks.end(), new_benchmarks.begin(), new_benchmarks.end()); + BENCHMARK_TYPE_P2(int) + BENCHMARK_TYPE_P2(float) + BENCHMARK_TYPE_P2(double) + BENCHMARK_TYPE_P2(int8_t) + BENCHMARK_TYPE_P2(uint8_t) + BENCHMARK_TYPE_P2(rocprim::half) + BENCHMARK_TYPE_P2(custom_double2) + BENCHMARK_TYPE_P2(custom_int_double) + BENCHMARK_TYPE_P2(rocprim::int128_t) + BENCHMARK_TYPE_P2(rocprim::uint128_t) } int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - - // Add benchmarks - std::vector benchmarks; - add_benchmarks(benchmarks, "inclusive_scan", stream, bytes); //inclusive - add_benchmarks(benchmarks, "exclusive_scan", stream, bytes); //exclusive - add_benchmarks(benchmarks, "broadcast", stream, bytes); //broadcast + benchmark_utils::executor executor(argc, argv, 128 * benchmark_utils::MiB, 1, 0); - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } + add_benchmarks(executor, "inclusive_scan"); + add_benchmarks(executor, "exclusive_scan"); + add_benchmarks(executor, "broadcast"); - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + executor.run(); } diff --git a/benchmark/benchmark_warp_sort.cpp b/benchmark/benchmark_warp_sort.cpp index 4a060a224..8b0f9c877 100644 --- a/benchmark/benchmark_warp_sort.cpp +++ b/benchmark/benchmark_warp_sort.cpp @@ -21,13 +21,11 @@ // SOFTWARE. #include "benchmark_utils.hpp" -// CmdParser -#include "cmdparser.hpp" #include "../common/utils_custom_type.hpp" -// Google Benchmark #include + // HIP API #include // rocPRIM @@ -42,10 +40,6 @@ #include #include -#ifndef DEFAULT_BYTES -const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; -#endif - template __global__ __launch_bounds__(BlockSize) void warp_sort_kernel(K* input_keys, K* output_keys) @@ -95,11 +89,12 @@ template -void run_benchmark(benchmark::State& state, - size_t bytes, - const managed_seed& seed, - hipStream_t stream) +void run_benchmark(benchmark_utils::state&& state) { + const auto& stream = state.stream; + const auto& bytes = state.bytes; + const auto& seed = state.seed; + // Calculate the number of elements size_t size = bytes / sizeof(Key); @@ -136,70 +131,49 @@ void run_benchmark(benchmark::State& state, hipMemcpyHostToDevice)); HIP_CHECK(hipDeviceSynchronize()); - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for(auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - if(SortByKey) + state.run( + [&] { - ROCPRIM_NO_UNROLL - for(unsigned int trial = 0; trial < Trials; ++trial) + if(SortByKey) { - hipLaunchKernelGGL( - HIP_KERNEL_NAME( - warp_sort_by_key_kernel), - dim3(size / items_per_block), - dim3(BlockSize), - 0, - stream, - d_input_key, - d_input_value, - d_output_key, - d_output_value); + ROCPRIM_NO_UNROLL + for(unsigned int trial = 0; trial < Trials; ++trial) + { + hipLaunchKernelGGL(HIP_KERNEL_NAME(warp_sort_by_key_kernel), + dim3(size / items_per_block), + dim3(BlockSize), + 0, + stream, + d_input_key, + d_input_value, + d_output_key, + d_output_value); + } } - } - else - { - ROCPRIM_NO_UNROLL - for(unsigned int trial = 0; trial < Trials; ++trial) + else { - hipLaunchKernelGGL( - HIP_KERNEL_NAME(warp_sort_kernel), - dim3(size / items_per_block), - dim3(BlockSize), - 0, - stream, - d_input_key, - d_output_key); + ROCPRIM_NO_UNROLL + for(unsigned int trial = 0; trial < Trials; ++trial) + { + hipLaunchKernelGGL( + HIP_KERNEL_NAME(warp_sort_kernel), + dim3(size / items_per_block), + dim3(BlockSize), + 0, + stream, + d_input_key, + d_output_key); + } } - } - HIP_CHECK(hipGetLastError()); + }); - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); + auto type_size = SortByKey ? sizeof(Key) + sizeof(Value) : sizeof(Key); - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - // SortByKey also transfers values - auto sorted_type_size = sizeof(Key); - if(SortByKey) - sorted_type_size += sizeof(Value); - state.SetBytesProcessed(state.iterations() * size * sorted_type_size * Trials); - state.SetItemsProcessed(state.iterations() * size * Trials); + state.set_throughput(size * Trials, type_size); HIP_CHECK(hipFree(d_input_key)); HIP_CHECK(hipFree(d_output_key)); @@ -207,71 +181,51 @@ void run_benchmark(benchmark::State& state, HIP_CHECK(hipFree(d_output_value)); } -#define CREATE_SORT_BENCHMARK(K, BS, WS, IPT) \ - benchmark::RegisterBenchmark( \ - bench_naming::format_name("{lvl:warp,algo:sort,key_type:" #K ",value_type:" \ - + std::string(Traits::name()) \ - + ",ws:" #WS ",cfg:{bs:" #BS ",ipt:" #IPT "}}") \ - .c_str(), \ - run_benchmark, \ - bytes, \ - seed, \ - stream) - -#define CREATE_SORTBYKEY_BENCHMARK(K, V, BS, WS, IPT) \ - benchmark::RegisterBenchmark(bench_naming::format_name("{lvl:warp,algo:sort,key_type:" #K \ - ",value_type:" #V ",ws:" #WS \ - ",cfg:{bs:" #BS ",ipt:" #IPT "}}") \ - .c_str(), \ - run_benchmark, \ - bytes, \ - seed, \ - stream) - -#define BENCHMARK_TYPE(type) \ - CREATE_SORT_BENCHMARK(type, 64, 64, 1), CREATE_SORT_BENCHMARK(type, 64, 64, 2), \ - CREATE_SORT_BENCHMARK(type, 64, 64, 4), CREATE_SORT_BENCHMARK(type, 128, 64, 1), \ - CREATE_SORT_BENCHMARK(type, 128, 64, 2), CREATE_SORT_BENCHMARK(type, 128, 64, 4), \ - CREATE_SORT_BENCHMARK(type, 256, 64, 1), CREATE_SORT_BENCHMARK(type, 256, 64, 2), \ - CREATE_SORT_BENCHMARK(type, 256, 64, 4), CREATE_SORT_BENCHMARK(type, 64, 32, 1), \ - CREATE_SORT_BENCHMARK(type, 64, 32, 2), CREATE_SORT_BENCHMARK(type, 64, 16, 1), \ - CREATE_SORT_BENCHMARK(type, 64, 16, 2), CREATE_SORT_BENCHMARK(type, 64, 16, 4) - -#define BENCHMARK_KEY_TYPE(type, value) \ - CREATE_SORTBYKEY_BENCHMARK(type, value, 64, 64, 1), \ - CREATE_SORTBYKEY_BENCHMARK(type, value, 64, 64, 2), \ - CREATE_SORTBYKEY_BENCHMARK(type, value, 64, 64, 4), \ - CREATE_SORTBYKEY_BENCHMARK(type, value, 256, 64, 1), \ - CREATE_SORTBYKEY_BENCHMARK(type, value, 256, 64, 2), \ - CREATE_SORTBYKEY_BENCHMARK(type, value, 256, 64, 4) +#define CREATE_SORT_BENCHMARK(K, BS, WS, IPT) \ + executor.queue_fn(bench_naming::format_name("{lvl:warp,algo:sort,key_type:" #K ",value_type:" \ + + std::string(Traits::name()) \ + + ",ws:" #WS ",cfg:{bs:" #BS ",ipt:" #IPT "}}") \ + .c_str(), \ + run_benchmark); + +#define CREATE_SORTBYKEY_BENCHMARK(K, V, BS, WS, IPT) \ + executor.queue_fn(bench_naming::format_name("{lvl:warp,algo:sort,key_type:" #K \ + ",value_type:" #V ",ws:" #WS ",cfg:{bs:" #BS \ + ",ipt:" #IPT "}}") \ + .c_str(), \ + run_benchmark); + +// clang-format off +#define BENCHMARK_TYPE(type) \ + CREATE_SORT_BENCHMARK(type, 64, 64, 1) \ + CREATE_SORT_BENCHMARK(type, 64, 64, 2) \ + CREATE_SORT_BENCHMARK(type, 64, 64, 4) \ + CREATE_SORT_BENCHMARK(type, 128, 64, 1) \ + CREATE_SORT_BENCHMARK(type, 128, 64, 2) \ + CREATE_SORT_BENCHMARK(type, 128, 64, 4) \ + CREATE_SORT_BENCHMARK(type, 256, 64, 1) \ + CREATE_SORT_BENCHMARK(type, 256, 64, 2) \ + CREATE_SORT_BENCHMARK(type, 256, 64, 4) \ + CREATE_SORT_BENCHMARK(type, 64, 32, 1) \ + CREATE_SORT_BENCHMARK(type, 64, 32, 2) \ + CREATE_SORT_BENCHMARK(type, 64, 16, 1) \ + CREATE_SORT_BENCHMARK(type, 64, 16, 2) \ + CREATE_SORT_BENCHMARK(type, 64, 16, 4) +// clang-format on + +// clang-format off +#define BENCHMARK_KEY_TYPE(type, value) \ + CREATE_SORTBYKEY_BENCHMARK(type, value, 64, 64, 1) \ + CREATE_SORTBYKEY_BENCHMARK(type, value, 64, 64, 2) \ + CREATE_SORTBYKEY_BENCHMARK(type, value, 64, 64, 4) \ + CREATE_SORTBYKEY_BENCHMARK(type, value, 256, 64, 1) \ + CREATE_SORTBYKEY_BENCHMARK(type, value, 256, 64, 2) \ + CREATE_SORTBYKEY_BENCHMARK(type, value, 256, 64, 4) +// clang-format on int main(int argc, char* argv[]) { - cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); - parser.set_optional("trials", "trials", -1, "number of iterations"); - parser.set_optional("name_format", - "name_format", - "human", - "either: json,human,txt"); - parser.set_optional("seed", "seed", "random", get_seed_message()); - parser.run_and_exit_if_error(); - - // Parse argv - benchmark::Initialize(&argc, argv); - const size_t bytes = parser.get("size"); - const int trials = parser.get("trials"); - bench_naming::set_format(parser.get("name_format")); - const std::string seed_type = parser.get("seed"); - const managed_seed seed(seed_type); - - // HIP - hipStream_t stream = 0; // default - - // Benchmark info - add_common_benchmark_info(); - benchmark::AddCustomContext("bytes", std::to_string(bytes)); - benchmark::AddCustomContext("seed", seed_type); + benchmark_utils::executor executor(argc, argv, 128 * benchmark_utils::MiB, 1, 0); using custom_double2 = common::custom_type; using custom_int_double = common::custom_type; @@ -280,46 +234,27 @@ int main(int argc, char* argv[]) using custom_char_double = common::custom_type; using custom_longlong_double = common::custom_type; - std::vector benchmarks - = {BENCHMARK_TYPE(int), - BENCHMARK_TYPE(float), - BENCHMARK_TYPE(double), - BENCHMARK_TYPE(int8_t), - BENCHMARK_TYPE(uint8_t), - BENCHMARK_TYPE(rocprim::half), - BENCHMARK_TYPE(rocprim::int128_t), - BENCHMARK_TYPE(rocprim::uint128_t), - - BENCHMARK_KEY_TYPE(float, float), - BENCHMARK_KEY_TYPE(unsigned int, int), - BENCHMARK_KEY_TYPE(int, custom_double2), - BENCHMARK_KEY_TYPE(int, custom_int_double), - BENCHMARK_KEY_TYPE(custom_int2, custom_double2), - BENCHMARK_KEY_TYPE(custom_int2, custom_char_double), - BENCHMARK_KEY_TYPE(custom_int2, custom_longlong_double), - BENCHMARK_KEY_TYPE(int8_t, int8_t), - BENCHMARK_KEY_TYPE(uint8_t, uint8_t), - BENCHMARK_KEY_TYPE(rocprim::half, rocprim::half), - BENCHMARK_KEY_TYPE(rocprim::int128_t, rocprim::int128_t), - BENCHMARK_KEY_TYPE(rocprim::uint128_t, rocprim::uint128_t)}; - - // Use manual timing - for(auto& b : benchmarks) - { - b->UseManualTime(); - b->Unit(benchmark::kMillisecond); - } - - // Force number of iterations - if(trials > 0) - { - for(auto& b : benchmarks) - { - b->Iterations(trials); - } - } - - // Run benchmarks - benchmark::RunSpecifiedBenchmarks(); - return 0; + BENCHMARK_TYPE(int) + BENCHMARK_TYPE(float) + BENCHMARK_TYPE(double) + BENCHMARK_TYPE(int8_t) + BENCHMARK_TYPE(uint8_t) + BENCHMARK_TYPE(rocprim::half) + BENCHMARK_TYPE(rocprim::int128_t) + BENCHMARK_TYPE(rocprim::uint128_t) + + BENCHMARK_KEY_TYPE(float, float) + BENCHMARK_KEY_TYPE(unsigned int, int) + BENCHMARK_KEY_TYPE(int, custom_double2) + BENCHMARK_KEY_TYPE(int, custom_int_double) + BENCHMARK_KEY_TYPE(custom_int2, custom_double2) + BENCHMARK_KEY_TYPE(custom_int2, custom_char_double) + BENCHMARK_KEY_TYPE(custom_int2, custom_longlong_double) + BENCHMARK_KEY_TYPE(int8_t, int8_t) + BENCHMARK_KEY_TYPE(uint8_t, uint8_t) + BENCHMARK_KEY_TYPE(rocprim::half, rocprim::half) + BENCHMARK_KEY_TYPE(rocprim::int128_t, rocprim::int128_t) + BENCHMARK_KEY_TYPE(rocprim::uint128_t, rocprim::uint128_t) + + executor.run(); } diff --git a/common/utils.hpp b/common/utils.hpp index 6b4de63e9..e43ffb16a 100644 --- a/common/utils.hpp +++ b/common/utils.hpp @@ -55,9 +55,8 @@ namespace common { template -__device__ -constexpr bool device_test_enabled_for_warp_size_v - = ::rocprim::arch::wavefront::min_size() >= LogicalWarpSize; +__device__ constexpr bool device_test_enabled_for_warp_size_v + = ::rocprim::arch::wavefront::max_size() >= LogicalWarpSize; inline char* __get_env(const char* name) { diff --git a/common/utils_data_generation.hpp b/common/utils_data_generation.hpp index adb930830..e01e0125a 100644 --- a/common/utils_data_generation.hpp +++ b/common/utils_data_generation.hpp @@ -24,7 +24,6 @@ #define COMMON_UTILS_DATA_GENERATION_HPP_ #include -#include #include #include @@ -159,7 +158,9 @@ struct generate_limits }; template -struct generate_limits::value>> +struct generate_limits< + T, + std::enable_if_t().is_build_in() && rocprim::is_integral::value>> { static inline T min() { @@ -172,7 +173,9 @@ struct generate_limits::value>> }; template -struct generate_limits::value>> +struct generate_limits().is_build_in() + && rocprim::is_floating_point::value>> { static inline T min() { diff --git a/docs/concepts/type_traits.rst b/docs/concepts/type_traits.rst index 525519959..731b81e63 100644 --- a/docs/concepts/type_traits.rst +++ b/docs/concepts/type_traits.rst @@ -49,7 +49,7 @@ Available traits .. doxygenstruct:: rocprim::traits::float_bit_mask :members: -.. doxygenstruct:: rocprim::traits::is_fundamental +.. doxygenstruct:: rocprim::traits::radix_key_codec :members: Type traits wrappers diff --git a/docs/reference/intrinsics.rst b/docs/reference/intrinsics.rst index f53f9fce5..2fab8cb2e 100644 --- a/docs/reference/intrinsics.rst +++ b/docs/reference/intrinsics.rst @@ -8,6 +8,16 @@ Intrinsics ******************************************************************** +Hardware Architecture +===================== + +.. doxygenfunction:: rocprim::arch::wavefront::size() +.. doxygenfunction:: rocprim::arch::wavefront::min_size() +.. doxygenfunction:: rocprim::arch::wavefront::max_size() + +.. doxygenenum:: rocprim::arch::wavefront::target +.. doxygenfunction:: rocprim::arch::wavefront::target() +.. doxygenfunction:: rocprim::arch::wavefront::size_from_target() Bitwise ======== @@ -21,10 +31,8 @@ Bitwise Warp size =========== -.. doxygenfunction:: rocprim::warp_size() .. doxygenfunction:: rocprim::host_warp_size(const int device_id, unsigned int& warp_size) .. doxygenfunction:: rocprim::host_warp_size(const hipStream_t stream, unsigned int& warp_size) -.. doxygenfunction:: rocprim::device_warp_size() Lane and Warp ID ================= diff --git a/docs/thread_ops/radix_key_codec.rst b/docs/thread_ops/radix_key_codec.rst deleted file mode 100644 index 718324651..000000000 --- a/docs/thread_ops/radix_key_codec.rst +++ /dev/null @@ -1,12 +0,0 @@ -.. meta:: - :description: rocPRIM documentation and API reference library - :keywords: rocPRIM, ROCm, API, documentation - -.. _radix-key-codec: - -******************************************************************** - Radix Key Encoder/Decoder -******************************************************************** - -.. doxygenclass:: rocprim::radix_key_codec - :members: diff --git a/rocprim/include/rocprim/block/block_adjacent_difference.hpp b/rocprim/include/rocprim/block/block_adjacent_difference.hpp index f8c2154de..11519e178 100644 --- a/rocprim/include/rocprim/block/block_adjacent_difference.hpp +++ b/rocprim/include/rocprim/block/block_adjacent_difference.hpp @@ -121,758 +121,6 @@ class block_adjacent_difference using storage_type = storage_type_; #endif - /// \brief Tags \p head_flags that indicate discontinuities between items partitioned - /// across the thread block, where the first item has no reference and is always - /// flagged. - /// \deprecated The flags API of block_adjacent_difference is deprecated, - /// use subtract_left() or block_discontinuity::flag_heads() instead. - /// \tparam ItemsPerThread [inferred] the number of items to be processed by - /// each thread. - /// \tparam Flag [inferred] the flag type. - /// \tparam FlagOp [inferred] type of binary function used for flagging. - /// - /// \param [out] head_flags array that contains the head flags. - /// \param [in] input array that data is loaded from. - /// \param [in] flag_op binary operation function object that will be used for flagging. - /// The signature of the function should be equivalent to the following: - /// bool f(const T &a, const T &b); or bool (const T& a, const T& b, unsigned int b_index);. - /// The signature does not need to have const &, but function object - /// must not modify the objects passed to it. - /// \param [in] storage reference to a temporary storage object of type storage_type. - /// - /// \par Storage reuse - /// Synchronization barrier should be placed before \p storage is reused - /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). - /// - /// \par Example. - /// \code{.cpp} - /// __global__ void example_kernel(...) - /// { - /// // specialize discontinuity for int and a block of 128 threads - /// using block_adjacent_difference_int = rocprim::block_adjacent_difference; - /// // allocate storage in shared memory - /// __shared__ block_adjacent_difference_int::storage_type storage; - /// - /// // segment of consecutive items to be used - /// int input[8]; - /// ... - /// int head_flags[8]; - /// block_adjacent_difference_int b_discontinuity; - /// using flag_op_type = typename rocprim::greater; - /// b_discontinuity.flag_heads(head_flags, input, flag_op_type(), storage); - /// ... - /// } - /// \endcode - template - [[deprecated("The flags API of block_adjacent_difference is deprecated." - "Use subtract_left or block_discontinuity.flag_heads instead.")]] - ROCPRIM_DEVICE ROCPRIM_INLINE - void flag_heads(Flag (&head_flags)[ItemsPerThread], - const T (&input)[ItemsPerThread], - FlagOp flag_op, - storage_type& storage) - { - static constexpr auto as_flags = true; - static constexpr auto reversed = true; - static constexpr auto with_predecessor = false; - base_type::template apply_left( - input, head_flags, flag_op, input[0] /* predecessor */, storage.get().left); - } - - /// \overload - /// \deprecated The flags API of block_adjacent_difference is deprecated, - /// use subtract_left() or block_discontinuity::flag_heads() instead. - /// This overload does not take a reference to temporary storage, instead it is declared as - /// part of the function itself. Note that this does NOT decrease the shared memory requirements - /// of a kernel using this function. - template - [[deprecated("The flags API of block_adjacent_difference is deprecated." - "Use subtract_left or block_discontinuity.flag_heads instead.")]] - ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE - void flag_heads(Flag (&head_flags)[ItemsPerThread], - const T (&input)[ItemsPerThread], - FlagOp flag_op) - { - ROCPRIM_SHARED_MEMORY storage_type storage; - flag_heads(head_flags, input, flag_op, storage); - } - - /// \brief Tags \p head_flags that indicate discontinuities between items partitioned - /// across the thread block, where the first item of the first thread is compared against - /// a \p tile_predecessor_item. - /// \deprecated The flags API of block_adjacent_difference is deprecated, - /// use subtract_left() or block_discontinuity::flag_heads() instead. - /// - /// \tparam ItemsPerThread [inferred] the number of items to be processed by - /// each thread. - /// \tparam Flag [inferred] the flag type. - /// \tparam FlagOp [inferred] type of binary function used for flagging. - /// - /// \param [out] head_flags array that contains the head flags. - /// \param [in] tile_predecessor_item first tile item from thread to be compared - /// against. - /// \param [in] input array that data is loaded from. - /// \param [in] flag_op binary operation function object that will be used for flagging. - /// The signature of the function should be equivalent to the following: - /// bool f(const T &a, const T &b); or bool (const T& a, const T& b, unsigned int b_index);. - /// The signature does not need to have const &, but function object - /// must not modify the objects passed to it. - /// \param [in] storage reference to a temporary storage object of type storage_type. - /// - /// \par Storage reuse - /// Synchronization barrier should be placed before \p storage is reused - /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). - /// - /// \par Example. - /// \code{.cpp} - /// __global__ void example_kernel(...) - /// { - /// // specialize discontinuity for int and a block of 128 threads - /// using block_adjacent_difference_int = rocprim::block_adjacent_difference; - /// // allocate storage in shared memory - /// __shared__ block_adjacent_difference_int::storage_type storage; - /// - /// // segment of consecutive items to be used - /// int input[8]; - /// int tile_item = 0; - /// if (threadIdx.x == 0) - /// { - /// tile_item = ... - /// } - /// ... - /// int head_flags[8]; - /// block_adjacent_difference_int b_discontinuity; - /// using flag_op_type = typename rocprim::greater; - /// b_discontinuity.flag_heads(head_flags, tile_item, input, flag_op_type(), - /// storage); - /// ... - /// } - /// \endcode - template - [[deprecated("The flags API of block_adjacent_difference is deprecated." - "Use subtract_left or block_discontinuity.flag_heads instead.")]] - ROCPRIM_DEVICE ROCPRIM_INLINE - void flag_heads(Flag (&head_flags)[ItemsPerThread], - T tile_predecessor_item, - const T (&input)[ItemsPerThread], - FlagOp flag_op, - storage_type& storage) - { - static constexpr auto as_flags = true; - static constexpr auto reversed = true; - static constexpr auto with_predecessor = true; - base_type::template apply_left( - input, head_flags, flag_op, tile_predecessor_item, storage.get().left); - } - - /// \overload - /// \deprecated The flags API of block_adjacent_difference is deprecated, - /// use subtract_left() or block_discontinuity::flag_heads() instead. - /// - /// This overload does not accept a reference to temporary storage, instead it is declared as - /// part of the function itself. Note that this does NOT decrease the shared memory requirements - /// of a kernel using this function. - template - [[deprecated("The flags API of block_adjacent_difference is deprecated." - "Use subtract_left or block_discontinuity.flag_heads instead.")]] - ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE - void flag_heads(Flag (&head_flags)[ItemsPerThread], - T tile_predecessor_item, - const T (&input)[ItemsPerThread], - FlagOp flag_op) - { - ROCPRIM_SHARED_MEMORY storage_type storage; - flag_heads(head_flags, tile_predecessor_item, input, flag_op, storage); - } - - /// \brief Tags \p tail_flags that indicate discontinuities between items partitioned - /// across the thread block, where the last item has no reference and is always - /// flagged. - /// \deprecated The flags API of block_adjacent_difference is deprecated, - /// use subtract_right() or block_discontinuity::flag_tails() instead. - /// - /// \tparam ItemsPerThread [inferred] the number of items to be processed by - /// each thread. - /// \tparam Flag [inferred] the flag type. - /// \tparam FlagOp [inferred] type of binary function used for flagging. - /// - /// \param [out] tail_flags array that contains the tail flags. - /// \param [in] input array that data is loaded from. - /// \param [in] flag_op binary operation function object that will be used for flagging. - /// The signature of the function should be equivalent to the following: - /// bool f(const T &a, const T &b); or bool (const T& a, const T& b, unsigned int b_index);. - /// The signature does not need to have const &, but function object - /// must not modify the objects passed to it. - /// \param [in] storage reference to a temporary storage object of type storage_type. - /// - /// \par Storage reuse - /// Synchronization barrier should be placed before \p storage is reused - /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). - /// - /// \par Example. - /// \code{.cpp} - /// __global__ void example_kernel(...) - /// { - /// // specialize discontinuity for int and a block of 128 threads - /// using block_adjacent_difference_int = rocprim::block_adjacent_difference; - /// // allocate storage in shared memory - /// __shared__ block_adjacent_difference_int::storage_type storage; - /// - /// // segment of consecutive items to be used - /// int input[8]; - /// ... - /// int tail_flags[8]; - /// block_adjacent_difference_int b_discontinuity; - /// using flag_op_type = typename rocprim::greater; - /// b_discontinuity.flag_tails(tail_flags, input, flag_op_type(), storage); - /// ... - /// } - /// \endcode - template - [[deprecated("The flags API of block_adjacent_difference is deprecated." - "Use subtract_right or block_discontinuity.flag_tails instead.")]] - ROCPRIM_DEVICE ROCPRIM_INLINE - void flag_tails(Flag (&tail_flags)[ItemsPerThread], - const T (&input)[ItemsPerThread], - FlagOp flag_op, - storage_type& storage) - { - static constexpr auto as_flags = true; - static constexpr auto reversed = true; - static constexpr auto with_successor = false; - base_type::template apply_right( - input, tail_flags, flag_op, input[0] /* successor */, storage.get().right); - } - - /// \overload - /// \deprecated The flags API of block_adjacent_difference is deprecated, - /// use subtract_right() or block_discontinuity::flag_tails() instead. - /// - /// This overload does not accept a reference to temporary storage, instead it is declared as - /// part of the function itself. Note that this does NOT decrease the shared memory requirements - /// of a kernel using this function. - template - [[deprecated("The flags API of block_adjacent_difference is deprecated." - "Use subtract_right or block_discontinuity.flag_tails instead.")]] - ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE - void flag_tails(Flag (&tail_flags)[ItemsPerThread], - const T (&input)[ItemsPerThread], - FlagOp flag_op) - { - ROCPRIM_SHARED_MEMORY storage_type storage; - flag_tails(tail_flags, input, flag_op, storage); - } - - /// \brief Tags \p tail_flags that indicate discontinuities between items partitioned - /// across the thread block, where the last item of the last thread is compared against - /// a \p tile_successor_item. - /// \deprecated The flags API of block_adjacent_difference is deprecated, - /// use subtract_right() or block_discontinuity::flag_tails() instead. - /// - /// \tparam ItemsPerThread [inferred] the number of items to be processed by - /// each thread. - /// \tparam Flag [inferred] the flag type. - /// \tparam FlagOp [inferred] type of binary function used for flagging. - /// - /// \param [out] tail_flags array that contains the tail flags. - /// \param [in] tile_successor_item last tile item from thread to be compared - /// against. - /// \param [in] input array that data is loaded from. - /// \param [in] flag_op binary operation function object that will be used for flagging. - /// The signature of the function should be equivalent to the following: - /// bool f(const T &a, const T &b); or bool (const T& a, const T& b, unsigned int b_index);. - /// The signature does not need to have const &, but function object - /// must not modify the objects passed to it. - /// \param [in] storage reference to a temporary storage object of type storage_type. - /// - /// \par Storage reuse - /// Synchronization barrier should be placed before \p storage is reused - /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). - /// - /// \par Example. - /// \code{.cpp} - /// __global__ void example_kernel(...) - /// { - /// // specialize discontinuity for int and a block of 128 threads - /// using block_adjacent_difference_int = rocprim::block_adjacent_difference; - /// // allocate storage in shared memory - /// __shared__ block_adjacent_difference_int::storage_type storage; - /// - /// // segment of consecutive items to be used - /// int input[8]; - /// int tile_item = 0; - /// if (threadIdx.x == 0) - /// { - /// tile_item = ... - /// } - /// ... - /// int tail_flags[8]; - /// block_adjacent_difference_int b_discontinuity; - /// using flag_op_type = typename rocprim::greater; - /// b_discontinuity.flag_tails(tail_flags, tile_item, input, flag_op_type(), - /// storage); - /// ... - /// } - /// \endcode - template - [[deprecated("The flags API of block_adjacent_difference is deprecated." - "Use subtract_right or block_discontinuity.flag_tails instead.")]] - ROCPRIM_DEVICE ROCPRIM_INLINE - void flag_tails(Flag (&tail_flags)[ItemsPerThread], - T tile_successor_item, - const T (&input)[ItemsPerThread], - FlagOp flag_op, - storage_type& storage) - { - static constexpr auto as_flags = true; - static constexpr auto reversed = true; - static constexpr auto with_successor = true; - base_type::template apply_right( - input, tail_flags, flag_op, tile_successor_item, storage.get().right); - } - - /// \overload - /// \deprecated The flags API of block_adjacent_difference is deprecated, - /// use subtract_right() or block_discontinuity::flag_tails() instead. - /// - /// This overload does not accept a reference to temporary storage, instead it is declared as - /// part of the function itself. Note that this does NOT decrease the shared memory requirements - /// of a kernel using this function. - template - [[deprecated("The flags API of block_adjacent_difference is deprecated." - "Use subtract_right or block_discontinuity.flag_tails instead.")]] - ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE - void flag_tails(Flag (&tail_flags)[ItemsPerThread], - T tile_successor_item, - const T (&input)[ItemsPerThread], - FlagOp flag_op) - { - ROCPRIM_SHARED_MEMORY storage_type storage; - flag_tails(tail_flags, tile_successor_item, input, flag_op, storage); - } - - /// \brief Tags both \p head_flags and\p tail_flags that indicate discontinuities - /// between items partitioned across the thread block. - /// - /// \tparam ItemsPerThread [inferred] the number of items to be processed by - /// each thread. - /// \tparam Flag [inferred] the flag type. - /// \tparam FlagOp [inferred] type of binary function used for flagging. - /// - /// \param [out] head_flags array that contains the head flags. - /// \param [out] tail_flags array that contains the tail flags. - /// \param [in] input array that data is loaded from. - /// \param [in] flag_op binary operation function object that will be used for flagging. - /// The signature of the function should be equivalent to the following: - /// bool f(const T &a, const T &b); or bool (const T& a, const T& b, unsigned int b_index);. - /// The signature does not need to have const &, but function object - /// must not modify the objects passed to it. - /// \param [in] storage reference to a temporary storage object of type storage_type. - /// - /// \par Storage reuse - /// Synchronization barrier should be placed before \p storage is reused - /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). - /// - /// \par Example. - /// \code{.cpp} - /// __global__ void example_kernel(...) - /// { - /// // specialize discontinuity for int and a block of 128 threads - /// using block_adjacent_difference_int = rocprim::block_adjacent_difference; - /// // allocate storage in shared memory - /// __shared__ block_adjacent_difference_int::storage_type storage; - /// - /// // segment of consecutive items to be used - /// int input[8]; - /// ... - /// int head_flags[8]; - /// int tail_flags[8]; - /// block_adjacent_difference_int b_discontinuity; - /// using flag_op_type = typename rocprim::greater; - /// b_discontinuity.flag_heads_and_tails(head_flags, tail_flags, input, - /// flag_op_type(), storage); - /// ... - /// } - /// \endcode - template - [[deprecated("The flags API of block_adjacent_difference is deprecated." - "Use block_discontinuity.flag_heads_and_tails instead.")]] - ROCPRIM_DEVICE ROCPRIM_INLINE - void flag_heads_and_tails(Flag (&head_flags)[ItemsPerThread], - Flag (&tail_flags)[ItemsPerThread], - const T (&input)[ItemsPerThread], - FlagOp flag_op, - storage_type& storage) - { - static constexpr auto as_flags = true; - static constexpr auto reversed = true; - static constexpr auto with_predecessor = false; - static constexpr auto with_successor = false; - - // Copy items in case head_flags is aliased with input - T items[ItemsPerThread]; - - ROCPRIM_UNROLL - for(unsigned int i = 0; i < ItemsPerThread; ++i) { - items[i] = input[i]; - } - - base_type::template apply_left( - items, head_flags, flag_op, items[0] /*predecessor*/, storage.get().left); - - base_type::template apply_right( - items, tail_flags, flag_op, items[0] /*successor*/, storage.get().right); - } - - /// \overload - /// \deprecated The flags API of block_adjacent_difference is deprecated, - /// use block_discontinuity::flag_heads_and_tails() instead. - /// - /// This overload does not accept a reference to temporary storage, instead it is declared as - /// part of the function itself. Note that this does NOT decrease the shared memory requirements - /// of a kernel using this function. - template - [[deprecated("The flags API of block_adjacent_difference is deprecated." - "Use block_discontinuity.flag_heads_and_tails instead.")]] - ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE - void flag_heads_and_tails(Flag (&head_flags)[ItemsPerThread], - Flag (&tail_flags)[ItemsPerThread], - const T (&input)[ItemsPerThread], - FlagOp flag_op) - { - ROCPRIM_SHARED_MEMORY storage_type storage; - flag_heads_and_tails(head_flags, tail_flags, input, flag_op, storage); - } - - /// \brief Tags both \p head_flags and\p tail_flags that indicate discontinuities - /// between items partitioned across the thread block, where the last item of the - /// last thread is compared against a \p tile_successor_item. - /// \deprecated The flags API of block_adjacent_difference is deprecated, - /// use block_discontinuity::flag_heads_and_tails() instead. - /// - /// \tparam ItemsPerThread [inferred] the number of items to be processed by - /// each thread. - /// \tparam Flag [inferred] the flag type. - /// \tparam FlagOp [inferred] type of binary function used for flagging. - /// - /// \param [out] head_flags array that contains the head flags. - /// \param [out] tail_flags array that contains the tail flags. - /// \param [in] tile_successor_item last tile item from thread to be compared - /// against. - /// \param [in] input array that data is loaded from. - /// \param [in] flag_op binary operation function object that will be used for flagging. - /// The signature of the function should be equivalent to the following: - /// bool f(const T &a, const T &b); or bool (const T& a, const T& b, unsigned int b_index);. - /// The signature does not need to have const &, but function object - /// must not modify the objects passed to it. - /// \param [in] storage reference to a temporary storage object of type storage_type. - /// - /// \par Storage reuse - /// Synchronization barrier should be placed before \p storage is reused - /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). - /// - /// \par Example. - /// \code{.cpp} - /// __global__ void example_kernel(...) - /// { - /// // specialize discontinuity for int and a block of 128 threads - /// using block_adjacent_difference_int = rocprim::block_adjacent_difference; - /// // allocate storage in shared memory - /// __shared__ block_adjacent_difference_int::storage_type storage; - /// - /// // segment of consecutive items to be used - /// int input[8]; - /// int tile_item = 0; - /// if (threadIdx.x == 0) - /// { - /// tile_item = ... - /// } - /// ... - /// int head_flags[8]; - /// int tail_flags[8]; - /// block_adjacent_difference_int b_discontinuity; - /// using flag_op_type = typename rocprim::greater; - /// b_discontinuity.flag_heads_and_tails(head_flags, tail_flags, tile_item, - /// input, flag_op_type(), - /// storage); - /// ... - /// } - /// \endcode - template - [[deprecated("The flags API of block_adjacent_difference is deprecated." - "Use block_discontinuity.flag_heads_and_tails instead.")]] - ROCPRIM_DEVICE ROCPRIM_INLINE - void flag_heads_and_tails(Flag (&head_flags)[ItemsPerThread], - Flag (&tail_flags)[ItemsPerThread], - T tile_successor_item, - const T (&input)[ItemsPerThread], - FlagOp flag_op, - storage_type& storage) - { - static constexpr auto as_flags = true; - static constexpr auto reversed = true; - static constexpr auto with_predecessor = false; - static constexpr auto with_successor = true; - - // Copy items in case head_flags is aliased with input - T items[ItemsPerThread]; - - ROCPRIM_UNROLL - for(unsigned int i = 0; i < ItemsPerThread; ++i) { - items[i] = input[i]; - } - - base_type::template apply_left( - items, head_flags, flag_op, items[0] /*predecessor*/, storage.get().left); - - base_type::template apply_right( - items, tail_flags, flag_op, tile_successor_item, storage.get().right); - } - - /// \overload - /// \deprecated The flags API of block_adjacent_difference is deprecated, - /// use block_discontinuity::flag_heads_and_tails() instead. - /// - /// This overload does not accept a reference to temporary storage, instead it is declared as - /// part of the function itself. Note that this does NOT decrease the shared memory requirements - /// of a kernel using this function. - template - [[deprecated("The flags API of block_adjacent_difference is deprecated." - "Use block_discontinuity.flag_heads_and_tails instead.")]] - ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE - void flag_heads_and_tails(Flag (&head_flags)[ItemsPerThread], - Flag (&tail_flags)[ItemsPerThread], - T tile_successor_item, - const T (&input)[ItemsPerThread], - FlagOp flag_op) - { - ROCPRIM_SHARED_MEMORY storage_type storage; - flag_heads_and_tails(head_flags, tail_flags, tile_successor_item, input, flag_op, storage); - } - - /// \brief Tags both \p head_flags and\p tail_flags that indicate discontinuities - /// between items partitioned across the thread block, where the first item of the - /// first thread is compared against a \p tile_predecessor_item. - /// \deprecated The flags API of block_adjacent_difference is deprecated, - /// use block_discontinuity::flag_heads_and_tails() instead. - /// - /// \tparam ItemsPerThread [inferred] the number of items to be processed by - /// each thread. - /// \tparam Flag [inferred] the flag type. - /// \tparam FlagOp [inferred] type of binary function used for flagging. - /// - /// \param [out] head_flags array that contains the head flags. - /// \param [in] tile_predecessor_item first tile item from thread to be compared - /// against. - /// \param [out] tail_flags array that contains the tail flags. - /// \param [in] input array that data is loaded from. - /// \param [in] flag_op binary operation function object that will be used for flagging. - /// The signature of the function should be equivalent to the following: - /// bool f(const T &a, const T &b); or bool (const T& a, const T& b, unsigned int b_index);. - /// The signature does not need to have const &, but function object - /// must not modify the objects passed to it. - /// \param [in] storage reference to a temporary storage object of type storage_type. - /// - /// \par Storage reuse - /// Synchronization barrier should be placed before \p storage is reused - /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). - /// - /// \par Example. - /// \code{.cpp} - /// __global__ void example_kernel(...) - /// { - /// // specialize discontinuity for int and a block of 128 threads - /// using block_adjacent_difference_int = rocprim::block_adjacent_difference; - /// // allocate storage in shared memory - /// __shared__ block_adjacent_difference_int::storage_type storage; - /// - /// // segment of consecutive items to be used - /// int input[8]; - /// int tile_item = 0; - /// if (threadIdx.x == 0) - /// { - /// tile_item = ... - /// } - /// ... - /// int head_flags[8]; - /// int tail_flags[8]; - /// block_adjacent_difference_int b_discontinuity; - /// using flag_op_type = typename rocprim::greater; - /// b_discontinuity.flag_heads_and_tails(head_flags, tile_item, tail_flags, - /// input, flag_op_type(), - /// storage); - /// ... - /// } - /// \endcode - template - [[deprecated("The flags API of block_adjacent_difference is deprecated." - "Use block_discontinuity.flag_heads_and_tails instead.")]] - ROCPRIM_DEVICE ROCPRIM_INLINE - void flag_heads_and_tails(Flag (&head_flags)[ItemsPerThread], - T tile_predecessor_item, - Flag (&tail_flags)[ItemsPerThread], - const T (&input)[ItemsPerThread], - FlagOp flag_op, - storage_type& storage) - { - static constexpr auto as_flags = true; - static constexpr auto reversed = true; - static constexpr auto with_predecessor = true; - static constexpr auto with_successor = false; - - // Copy items in case head_flags is aliased with input - T items[ItemsPerThread]; - - ROCPRIM_UNROLL - for(unsigned int i = 0; i < ItemsPerThread; ++i) { - items[i] = input[i]; - } - - base_type::template apply_left( - items, head_flags, flag_op, tile_predecessor_item, storage.get().left); - - base_type::template apply_right( - items, tail_flags, flag_op, items[0] /*successor*/, storage.get().right); - } - - /// \overload - /// \deprecated The flags API of block_adjacent_difference is deprecated, - /// use block_discontinuity::flag_heads_and_tails() instead. - /// - /// This overload does not accept a reference to temporary storage, instead it is declared as - /// part of the function itself. Note that this does NOT decrease the shared memory requirements - /// of a kernel using this function. - template - [[deprecated("The flags API of block_adjacent_difference is deprecated." - "Use block_discontinuity.flag_heads_and_tails instead.")]] - ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE - void flag_heads_and_tails(Flag (&head_flags)[ItemsPerThread], - T tile_predecessor_item, - Flag (&tail_flags)[ItemsPerThread], - const T (&input)[ItemsPerThread], - FlagOp flag_op) - { - ROCPRIM_SHARED_MEMORY storage_type storage; - flag_heads_and_tails(head_flags, tile_predecessor_item, tail_flags, input, flag_op, storage); - } - - /// \brief Tags both \p head_flags and\p tail_flags that indicate discontinuities - /// between items partitioned across the thread block, where the first and last items of - /// the first and last thread is compared against a \p tile_predecessor_item and - /// a \p tile_successor_item. - /// \deprecated The flags API of block_adjacent_difference is deprecated, - /// use block_discontinuity::flag_heads_and_tails() instead. - /// - /// \tparam ItemsPerThread [inferred] the number of items to be processed by - /// each thread. - /// \tparam Flag [inferred] the flag type. - /// \tparam FlagOp [inferred] type of binary function used for flagging. - /// - /// \param [out] head_flags array that contains the head flags. - /// \param [in] tile_predecessor_item first tile item from thread to be compared - /// against. - /// \param [out] tail_flags array that contains the tail flags. - /// \param [in] tile_successor_item last tile item from thread to be compared - /// against. - /// \param [in] input array that data is loaded from. - /// \param [in] flag_op binary operation function object that will be used for flagging. - /// The signature of the function should be equivalent to the following: - /// bool f(const T &a, const T &b); or bool (const T& a, const T& b, unsigned int b_index);. - /// The signature does not need to have const &, but function object - /// must not modify the objects passed to it. - /// \param [in] storage reference to a temporary storage object of type storage_type. - /// - /// \par Storage reuse - /// Synchronization barrier should be placed before \p storage is reused - /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). - /// - /// \par Example. - /// \code{.cpp} - /// __global__ void example_kernel(...) - /// { - /// // specialize discontinuity for int and a block of 128 threads - /// using block_adjacent_difference_int = rocprim::block_adjacent_difference; - /// // allocate storage in shared memory - /// __shared__ block_adjacent_difference_int::storage_type storage; - /// - /// // segment of consecutive items to be used - /// int input[8]; - /// int tile_predecessor_item = 0; - /// int tile_successor_item = 0; - /// if (threadIdx.x == 0) - /// { - /// tile_predecessor_item = ... - /// tile_successor_item = ... - /// } - /// ... - /// int head_flags[8]; - /// int tail_flags[8]; - /// block_adjacent_difference_int b_discontinuity; - /// using flag_op_type = typename rocprim::greater; - /// b_discontinuity.flag_heads_and_tails(head_flags, tile_predecessor_item, - /// tail_flags, tile_successor_item, - /// input, flag_op_type(), - /// storage); - /// ... - /// } - /// \endcode - template - [[deprecated("The flags API of block_adjacent_difference is deprecated." - "Use block_discontinuity.flag_heads_and_tails instead.")]] - ROCPRIM_DEVICE ROCPRIM_INLINE - void flag_heads_and_tails(Flag (&head_flags)[ItemsPerThread], - T tile_predecessor_item, - Flag (&tail_flags)[ItemsPerThread], - T tile_successor_item, - const T (&input)[ItemsPerThread], - FlagOp flag_op, - storage_type& storage) - { - static constexpr auto as_flags = true; - static constexpr auto reversed = true; - static constexpr auto with_predecessor = true; - static constexpr auto with_successor = true; - - // Copy items in case head_flags is aliased with input - T items[ItemsPerThread]; - - ROCPRIM_UNROLL - for(unsigned int i = 0; i < ItemsPerThread; ++i) { - items[i] = input[i]; - } - - base_type::template apply_left( - items, head_flags, flag_op, tile_predecessor_item, storage.get().left); - - base_type::template apply_right( - items, tail_flags, flag_op, tile_successor_item, storage.get().right); - } - - /// \overload - /// \deprecated The flags API of block_adjacent_difference is deprecated, - /// use block_discontinuity::flag_heads_and_tails() instead. - /// - /// This overload does not accept a reference to temporary storage, instead it is declared as - /// part of the function itself. Note that this does NOT decrease the shared memory requirements - /// of a kernel using this function. - template - [[deprecated("The flags API of block_adjacent_difference is deprecated." - "Use block_discontinuity.flag_heads_and_tails instead.")]] - ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE - void flag_heads_and_tails(Flag (&head_flags)[ItemsPerThread], - T tile_predecessor_item, - Flag (&tail_flags)[ItemsPerThread], - T tile_successor_item, - const T (&input)[ItemsPerThread], - FlagOp flag_op) - { - ROCPRIM_SHARED_MEMORY storage_type storage; - flag_heads_and_tails( - head_flags, tile_predecessor_item, tail_flags, tile_successor_item, - input, flag_op, storage - ); - } - /// \brief Apply a function to each consecutive pair of elements partitioned across threads in /// the block and write the output to the position of the left item. /// diff --git a/rocprim/include/rocprim/block/block_exchange.hpp b/rocprim/include/rocprim/block/block_exchange.hpp index 7692d3e54..51b2f1cdc 100644 --- a/rocprim/include/rocprim/block/block_exchange.hpp +++ b/rocprim/include/rocprim/block/block_exchange.hpp @@ -77,20 +77,20 @@ BEGIN_ROCPRIM_NAMESPACE /// } /// \endcode /// \endparblock -template< - class T, - unsigned int BlockSizeX, - unsigned int ItemsPerThread, - unsigned int BlockSizeY = 1, - unsigned int BlockSizeZ = 1, - block_padding_hint PaddingHint = block_padding_hint::avoid_conflicts -> +template class block_exchange { static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ; // Select warp size - static constexpr unsigned int warp_size - = detail::get_min_warp_size(BlockSize, ::rocprim::arch::wavefront::min_size()); + static constexpr unsigned int warp_size = ::rocprim::detail::get_min_warp_size( + BlockSize, ::rocprim::arch::wavefront::size_from_target()); // Number of warps in block static constexpr unsigned int warps_no = ::rocprim::detail::ceiling_div(BlockSize, warp_size); static constexpr unsigned int banks_no = ::rocprim::detail::get_lds_banks_no(); @@ -657,23 +657,26 @@ class block_exchange /// ... /// } /// \endcode - template + template(), + class U, + class Offset> ROCPRIM_DEVICE ROCPRIM_INLINE void scatter_to_warp_striped(const T (&input)[ItemsPerThread], U (&output)[ItemsPerThread], const Offset (&ranks)[ItemsPerThread], storage_type& storage) { - static_assert(detail::is_power_of_two(WarpSize) && WarpSize <= arch::wavefront::max_size(), - "WarpSize must be a power of two and equal or less" + static_assert(detail::is_power_of_two(VirtualWaveSize) + && VirtualWaveSize <= arch::wavefront::max_size(), + "VirtualWaveSize must be a power of two and equal or less" "than the size of hardware warp."); - assert(WarpSize <= arch::wavefront::size()); + assert(VirtualWaveSize <= arch::wavefront::size()); const unsigned int flat_id = ::rocprim::flat_block_thread_id(); - const unsigned int thread_id = detail::logical_lane_id(); - const unsigned int warp_id = flat_id / WarpSize; - const unsigned int warp_offset = warp_id * WarpSize * ItemsPerThread; + const unsigned int thread_id = detail::logical_lane_id(); + const unsigned int warp_id = flat_id / VirtualWaveSize; + const unsigned int warp_offset = warp_id * VirtualWaveSize * ItemsPerThread; const unsigned int thread_offset = thread_id + warp_offset; ROCPRIM_UNROLL @@ -690,7 +693,7 @@ class block_exchange ROCPRIM_UNROLL for(unsigned int i = 0; i < ItemsPerThread; i++) { - output[i] = storage_buffer[index(thread_offset + i * WarpSize)]; + output[i] = storage_buffer[index(thread_offset + i * VirtualWaveSize)]; } } @@ -884,6 +887,121 @@ class block_exchange } }; +#ifndef DOXYGEN_SHOULD_SKIP_THIS + +template +class block_exchange +{ +private: + using block_exchange_wave32 = block_exchange; + using block_exchange_wave64 = block_exchange; + using dispatch + = ::rocprim::detail::dispatch_wave_size; + +public: + using storage_type = typename dispatch::storage_type; + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto blocked_to_striped(Args&&... args) + { + dispatch{}([](auto impl, auto&&... args) { impl.blocked_to_striped(args...); }, args...); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto striped_to_blocked(Args&&... args) + { + dispatch{}([](auto impl, auto&&... args) { impl.striped_to_blocked(args...); }, args...); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto blocked_to_warp_striped(Args&&... args) + { + dispatch{}([](auto impl, auto&&... args) { impl.blocked_to_warp_striped(args...); }, + args...); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto warp_striped_to_blocked(Args&&... args) + { + dispatch{}([](auto impl, auto&&... args) { impl.warp_striped_to_blocked(args...); }, + args...); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto scatter_to_blocked(Args&&... args) + { + dispatch{}([](auto impl, auto&&... args) { impl.scatter_to_blocked(args...); }, args...); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto gather_from_striped(Args&&... args) + { + dispatch{}([](auto impl, auto&&... args) { impl.gather_from_striped(args...); }, args...); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto scatter_to_striped(Args&&... args) + { + dispatch{}([](auto impl, auto&&... args) { impl.scatter_to_striped(args...); }, args...); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto scatter_to_warp_striped(Args&&... args) + { + dispatch{}([](auto impl, auto&&... args) { impl.scatter_to_warp_striped(args...); }, + args...); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto scatter_to_striped_guarded(Args&&... args) + { + dispatch{}([](auto impl, auto&&... args) { impl.scatter_to_striped_guarded(args...); }, + args...); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto scatter_to_striped_flagged(Args&&... args) + { + dispatch{}([](auto impl, auto&&... args) { impl.scatter_to_striped_flagged(args...); }, + args...); + } +}; + +#endif + END_ROCPRIM_NAMESPACE /// @} diff --git a/rocprim/include/rocprim/block/block_load_func.hpp b/rocprim/include/rocprim/block/block_load_func.hpp index 61963f26a..0fae56776 100644 --- a/rocprim/include/rocprim/block/block_load_func.hpp +++ b/rocprim/include/rocprim/block/block_load_func.hpp @@ -29,6 +29,8 @@ #include "../types.hpp" #include "rocprim/intrinsics/arch.hpp" +#include "../thread/thread_load.hpp" + /// \addtogroup blockmodule /// @{ @@ -347,18 +349,18 @@ void block_load_direct_striped(unsigned int flat_id, /// across the thread block. /// /// \ingroup blockmodule_warp_load_functions -/// The warp-striped arrangement is assumed to be (\p WarpSize * \p ItemsPerThread) items +/// The warp-striped arrangement is assumed to be (\p VirtualWaveSize * \p ItemsPerThread) items /// across a thread block. Each thread uses a \p flat_id to load a range of /// \p ItemsPerThread into \p items. /// -/// * The number of threads in the block must be a multiple of \p WarpSize. -/// * The default \p WarpSize is a hardware warpsize and is an optimal value. -/// * \p WarpSize must be a power of two and equal or less than the size of +/// * The number of threads in the block must be a multiple of \p VirtualWaveSize. +/// * The default \p VirtualWaveSize is a hardware warpsize and is an optimal value. +/// * \p VirtualWaveSize must be a power of two and equal or less than the size of /// hardware warp. -/// * Using \p WarpSize smaller than hardware warpsize could result in lower +/// * Using \p VirtualWaveSize smaller than hardware warpsize could result in lower /// performance. /// -/// \tparam WarpSize [optional] the number of threads in a warp +/// \tparam VirtualWaveSize [optional] the number of threads in a warp /// \tparam InputIterator [inferred] an iterator type for input (can be a simple /// pointer /// \tparam T [inferred] the data type @@ -368,7 +370,7 @@ void block_load_direct_striped(unsigned int flat_id, /// \param flat_id a local flat 1D thread id in a block (tile) for the calling thread /// \param block_input the input iterator from the thread block to load from /// \param items array that data is loaded to -template @@ -377,20 +379,21 @@ void block_load_direct_warp_striped(unsigned int flat_id, InputIterator block_input, T (&items)[ItemsPerThread]) { - static_assert(detail::is_power_of_two(WarpSize) && WarpSize <= arch::wavefront::max_size(), - "WarpSize must be a power of two and equal or less" + static_assert(detail::is_power_of_two(VirtualWaveSize) + && VirtualWaveSize <= arch::wavefront::max_size(), + "VirtualWaveSize must be a power of two and equal or less" "than the size of hardware warp."); - assert(WarpSize <= arch::wavefront::size()); + assert(VirtualWaveSize <= arch::wavefront::size()); - unsigned int thread_id = detail::logical_lane_id(); - unsigned int warp_id = flat_id / WarpSize; - unsigned int warp_offset = warp_id * WarpSize * ItemsPerThread; + unsigned int thread_id = detail::logical_lane_id(); + unsigned int warp_id = flat_id / VirtualWaveSize; + unsigned int warp_offset = warp_id * VirtualWaveSize * ItemsPerThread; InputIterator thread_iter = block_input + thread_id + warp_offset; ROCPRIM_UNROLL for (unsigned int item = 0; item < ItemsPerThread; item++) { - items[item] = thread_iter[item * WarpSize]; + items[item] = thread_iter[item * VirtualWaveSize]; } } @@ -398,18 +401,18 @@ void block_load_direct_warp_striped(unsigned int flat_id, /// across the thread block, which is guarded by range \p valid. /// /// \ingroup blockmodule_warp_load_functions -/// The warp-striped arrangement is assumed to be (\p WarpSize * \p ItemsPerThread) items +/// The warp-striped arrangement is assumed to be (\p VirtualWaveSize * \p ItemsPerThread) items /// across a thread block. Each thread uses a \p flat_id to load a range of /// \p ItemsPerThread into \p items. /// -/// * The number of threads in the block must be a multiple of \p WarpSize. -/// * The default \p WarpSize is a hardware warpsize and is an optimal value. -/// * \p WarpSize must be a power of two and equal or less than the size of +/// * The number of threads in the block must be a multiple of \p VirtualWaveSize. +/// * The default \p VirtualWaveSize is a hardware warpsize and is an optimal value. +/// * \p VirtualWaveSize must be a power of two and equal or less than the size of /// hardware warp. -/// * Using \p WarpSize smaller than hardware warpsize could result in lower +/// * Using \p VirtualWaveSize smaller than hardware warpsize could result in lower /// performance. /// -/// \tparam WarpSize [optional] the number of threads in a warp +/// \tparam VirtualWaveSize [optional] the number of threads in a warp /// \tparam InputIterator [inferred] an iterator type for input (can be a simple /// pointer /// \tparam T [inferred] the data type @@ -420,7 +423,7 @@ void block_load_direct_warp_striped(unsigned int flat_id, /// \param block_input the input iterator from the thread block to load from /// \param items array that data is loaded to /// \param valid maximum range of valid numbers to load -template @@ -430,20 +433,21 @@ void block_load_direct_warp_striped(unsigned int flat_id, T (&items)[ItemsPerThread], unsigned int valid) { - static_assert(detail::is_power_of_two(WarpSize) && WarpSize <= arch::wavefront::max_size(), - "WarpSize must be a power of two and equal or less" + static_assert(detail::is_power_of_two(VirtualWaveSize) + && VirtualWaveSize <= arch::wavefront::max_size(), + "VirtualWaveSize must be a power of two and equal or less" "than the size of hardware warp."); - assert(WarpSize <= arch::wavefront::size()); + assert(VirtualWaveSize <= arch::wavefront::size()); - unsigned int thread_id = detail::logical_lane_id(); - unsigned int warp_id = flat_id / WarpSize; - unsigned int warp_offset = warp_id * WarpSize * ItemsPerThread; + unsigned int thread_id = detail::logical_lane_id(); + unsigned int warp_id = flat_id / VirtualWaveSize; + unsigned int warp_offset = warp_id * VirtualWaveSize * ItemsPerThread; InputIterator thread_iter = block_input + thread_id + warp_offset; ROCPRIM_UNROLL for (unsigned int item = 0; item < ItemsPerThread; item++) { - unsigned int offset = item * WarpSize; + unsigned int offset = item * VirtualWaveSize; if (warp_offset + thread_id + offset < valid) { items[item] = thread_iter[offset]; @@ -456,18 +460,18 @@ void block_load_direct_warp_striped(unsigned int flat_id, /// for out-of-bound elements. /// /// \ingroup blockmodule_warp_load_functions -/// The warp-striped arrangement is assumed to be (\p WarpSize * \p ItemsPerThread) items +/// The warp-striped arrangement is assumed to be (\p VirtualWaveSize * \p ItemsPerThread) items /// across a thread block. Each thread uses a \p flat_id to load a range of /// \p ItemsPerThread into \p items. /// -/// * The number of threads in the block must be a multiple of \p WarpSize. -/// * The default \p WarpSize is a hardware warpsize and is an optimal value. -/// * \p WarpSize must be a power of two and equal or less than the size of +/// * The number of threads in the block must be a multiple of \p VirtualWaveSize. +/// * The default \p VirtualWaveSize is a hardware warpsize and is an optimal value. +/// * \p VirtualWaveSize must be a power of two and equal or less than the size of /// hardware warp. -/// * Using \p WarpSize smaller than hardware warpsize could result in lower +/// * Using \p VirtualWaveSize smaller than hardware warpsize could result in lower /// performance. /// -/// \tparam WarpSize [optional] the number of threads in a warp +/// \tparam VirtualWaveSize [optional] the number of threads in a warp /// \tparam InputIterator [inferred] an iterator type for input (can be a simple /// pointer /// \tparam T [inferred] the data type @@ -480,7 +484,7 @@ void block_load_direct_warp_striped(unsigned int flat_id, /// \param items array that data is loaded to /// \param valid maximum range of valid numbers to load /// \param out_of_bounds default value assigned to out-of-bound items -template(flat_id, block_input, items, valid); + block_load_direct_warp_striped(flat_id, block_input, items, valid); +} + +template +ROCPRIM_DEVICE ROCPRIM_INLINE +auto block_load_direct_blocked_cast(unsigned int flat_id, + T* block_input, + U (&items)[ItemsPerThread]) -> + typename std::enable_if::value + && (ItemsPerThread * sizeof(T)) % sizeof(V) == 0>::type +{ + static_assert(detail::is_power_of_two(VirtualWaveSize) + && VirtualWaveSize <= arch::wavefront::max_size(), + "VirtualWaveSize must be a power of two and equal or less" + "than the size of hardware warp."); + assert(VirtualWaveSize <= arch::wavefront::size()); + + constexpr unsigned int vectors_per_thread = (sizeof(T) * ItemsPerThread) / sizeof(V); + + const V* vector_ptr + = ::rocprim::detail::bit_cast(block_input) + flat_id * vectors_per_thread; + + ROCPRIM_UNROLL + for(unsigned int item = 0; item < vectors_per_thread; item++) + { + reinterpret_cast(items)[item] = thread_load(vector_ptr + item); + } +} + +template +ROCPRIM_DEVICE ROCPRIM_INLINE +auto block_load_direct_blocked_cast(unsigned int flat_id, + T* block_input, + U (&items)[ItemsPerThread]) -> + typename std::enable_if::value + || (ItemsPerThread * sizeof(T)) % sizeof(V) != 0>::type +{ + block_load_direct_blocked(flat_id, block_input, items); } END_ROCPRIM_NAMESPACE diff --git a/rocprim/include/rocprim/block/block_radix_rank.hpp b/rocprim/include/rocprim/block/block_radix_rank.hpp index 82ca16d6f..610b8b77a 100644 --- a/rocprim/include/rocprim/block/block_radix_rank.hpp +++ b/rocprim/include/rocprim/block/block_radix_rank.hpp @@ -47,6 +47,8 @@ enum class block_radix_rank_algorithm match, /// \brief The default radix ranking algorithm. default_algorithm = basic, + /// \brief The placeholder for radix_sort default + default_for_radix_sort, }; namespace detail @@ -62,8 +64,10 @@ struct select_block_radix_rank_impl unsigned int RadixBits, unsigned int BlockSizeY, unsigned int BlockSizeZ, - block_padding_hint> - using type = block_radix_rank; + block_padding_hint, + arch::wavefront::target TargetWaveSize> + using type + = block_radix_rank; }; template<> @@ -73,19 +77,27 @@ struct select_block_radix_rank_impl unsigned int RadixBits, unsigned int BlockSizeY, unsigned int BlockSizeZ, - block_padding_hint> - using type = block_radix_rank; + block_padding_hint, + arch::wavefront::target TargetWaveSize> + using type + = block_radix_rank; }; template<> struct select_block_radix_rank_impl { - template - using type = block_radix_rank_match; + template + using type = block_radix_rank_match; }; } // namespace detail @@ -140,18 +152,20 @@ struct select_block_radix_rank_impl /// \endcode template + block_radix_rank_algorithm Algorithm = block_radix_rank_algorithm::default_algorithm, + unsigned int BlockSizeY = 1, + unsigned int BlockSizeZ = 1, + block_padding_hint PaddingHint = block_padding_hint::avoid_conflicts, + arch::wavefront::target TargetWaveSize = arch::wavefront::get_target(), + typename Enabled = void> class block_radix_rank #ifndef DOXYGEN_SHOULD_SKIP_THIS - : private detail::select_block_radix_rank_impl< - Algorithm>::template type + : private detail::select_block_radix_rank_impl:: + template type #endif { - using base_type = typename detail::select_block_radix_rank_impl< - Algorithm>::template type; + using base_type = typename detail::select_block_radix_rank_impl:: + template type; public: /// \brief The number of digits each thread will process. @@ -553,6 +567,63 @@ class block_radix_rank } }; +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template +class block_radix_rank +{ +private: + using block_radix_rank_wave32 = block_radix_rank; + using block_radix_rank_wave64 = block_radix_rank; + + using dispatch = detail::dispatch_wave_size; + +public: + static_assert(block_radix_rank_wave32::digits_per_thread + == block_radix_rank_wave64::digits_per_thread, + "digits_per_thread is not the same for wavefront size 32 and 64!"); + static constexpr unsigned int digits_per_thread = block_radix_rank_wave32::digits_per_thread; + + using storage_type = typename dispatch::storage_type; + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto rank_keys(Args&&... args) + { + dispatch{}([](auto impl, auto&&... args) { impl.rank_keys(args...); }, args...); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto rank_keys_desc(Args&&... args) + { + dispatch{}([](auto impl, auto&&... args) { impl.rank_keys_desc(args...); }, args...); + } +}; +#endif // DOXYGEN_SHOULD_SKIP_THIS + END_ROCPRIM_NAMESPACE /// @} diff --git a/rocprim/include/rocprim/block/block_radix_sort.hpp b/rocprim/include/rocprim/block/block_radix_sort.hpp index 60e6f4493..cc19acbc9 100644 --- a/rocprim/include/rocprim/block/block_radix_sort.hpp +++ b/rocprim/include/rocprim/block/block_radix_sort.hpp @@ -27,7 +27,6 @@ #include "../detail/various.hpp" #include "../functional.hpp" #include "../intrinsics/thread.hpp" -#include "../thread/radix_key_codec.hpp" #include "../types.hpp" #include "../warp/warp_exchange.hpp" @@ -99,27 +98,40 @@ BEGIN_ROCPRIM_NAMESPACE template + = block_radix_rank_algorithm::default_for_radix_sort, + block_padding_hint PaddingHint = block_padding_hint::lds_occupancy_bound, + arch::wavefront::target TargetWaveSize = arch::wavefront::get_target()> class block_radix_sort { - static_assert(RadixBitsPerPass > 0 && RadixBitsPerPass < 32, - "The RadixBitsPerPass should be larger than 0 and smaller than the size " + // TODO: somehow when prefer_match is true on SPIR-V, results + // are incorrect. Block radix rank works fine though... + static constexpr bool prefer_match = (BlockSizeX * BlockSizeY * BlockSizeZ) + % arch::wavefront::size_from_target() + == 0; + + static constexpr unsigned int radix_bits_per_pass = RadixBitsPerPass == 0 + ? (prefer_match ? 8 /* match */ + : 4 /* basic_memoize */) + : RadixBitsPerPass; + + static constexpr block_radix_rank_algorithm radix_rank_algorithm + = RadixRankAlgorithm == block_radix_rank_algorithm::default_for_radix_sort + ? (prefer_match ? block_radix_rank_algorithm::match + : block_radix_rank_algorithm::basic_memoize) + : RadixRankAlgorithm; + + static_assert(radix_bits_per_pass > 0 && radix_bits_per_pass < 32, + "The radix_bits_per_pass should be larger than 0 and smaller than the size " "of an unsigned int"); static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ; static constexpr bool with_values = !std::is_same::value; - static constexpr bool warp_striped = RadixRankAlgorithm == block_radix_rank_algorithm::match; + static constexpr bool warp_striped = radix_rank_algorithm == block_radix_rank_algorithm::match; ROCPRIM_DETAIL_DEVICE_STATIC_ASSERT( !warp_striped || (BlockSize % ::rocprim::arch::wavefront::min_size()) == 0, @@ -129,12 +141,29 @@ class block_radix_sort static constexpr bool is_key_and_value_aligned = alignof(Key) == alignof(Value) && sizeof(Key) == sizeof(Value); - using block_rank_type = ::rocprim:: - block_radix_rank; - using keys_exchange_type - = ::rocprim::block_exchange; - using values_exchange_type - = ::rocprim::block_exchange; + using block_rank_type = ::rocprim::block_radix_rank; + + using keys_exchange_type = ::rocprim::block_exchange; + + using values_exchange_type = ::rocprim::block_exchange; // Struct used for creating a raw_storage object for this primitive's temporary storage. union storage_type_ @@ -1066,9 +1095,6 @@ class block_radix_sort } private: - static constexpr bool use_warp_exchange - = ::rocprim::arch::wavefront::min_size() % ItemsPerThread == 0 && ItemsPerThread <= 4; - template ROCPRIM_DEVICE ROCPRIM_INLINE void blocked_to_warp_striped(Key (&keys)[ItemsPerThread], @@ -1077,7 +1103,7 @@ class block_radix_sort std::false_type) { keys_exchange_type().blocked_to_warp_striped(keys, keys, storage.get().keys_exchange); - if ROCPRIM_IF_CONSTEXPR(is_key_and_value_aligned) + if constexpr(is_key_and_value_aligned) { // If keys and values are aligned, then the LDS for both exchanges is // local per wave. We can relax the data dependency! @@ -1099,9 +1125,14 @@ class block_radix_sort storage_type& /* storage */, std::true_type) { - ::rocprim::warp_exchange{}.blocked_to_striped_shuffle(keys, keys); - ::rocprim::warp_exchange{}.blocked_to_striped_shuffle(values, - values); + constexpr int wave_size = ::rocprim::arch::wavefront::size_from_target(); + using keys_warp_exchange + = ::rocprim::warp_exchange; + using values_warp_exchange + = ::rocprim::warp_exchange; + + keys_warp_exchange{}.blocked_to_striped_shuffle(keys, keys); + values_warp_exchange{}.blocked_to_striped_shuffle(values, values); } template; + using key_codec + = decltype(::rocprim::traits::get().template radix_key_codec()); // 'rank_keys' may be invoked multiple times. We encode the key once and move the // encoded during the majority of sort to save on some compute. @@ -1129,11 +1161,11 @@ class block_radix_sort // If we're using warp striped radix rank but our input is in a blocked layout, we // can emulate the correct input through an exchange to a warp striped layout. - if ROCPRIM_IF_CONSTEXPR(TryEmulateWarpStriped && warp_striped && ItemsPerThread > 1) + if constexpr(TryEmulateWarpStriped && warp_striped && ItemsPerThread > 1) { // This appears to be slower with high large items per thread. constexpr bool use_warp_exchange - = ::rocprim::arch::wavefront::min_size() % ItemsPerThread == 0 + = arch::wavefront::size_from_target() % ItemsPerThread == 0 && ItemsPerThread <= 4; blocked_to_warp_striped(keys, values, @@ -1147,7 +1179,7 @@ class block_radix_sort unsigned int ranks[ItemsPerThread]; while(true) { - const int pass_bits = min(RadixBitsPerPass, end_bit - begin_bit); + const int pass_bits = min(radix_bits_per_pass, end_bit - begin_bit); block_rank_type().rank_keys( keys, @@ -1155,14 +1187,14 @@ class block_radix_sort storage.get().rank, [begin_bit, pass_bits, decomposer](const Key& key) mutable { return key_codec::extract_digit(key, begin_bit, pass_bits, decomposer); }); - begin_bit += RadixBitsPerPass; + begin_bit += radix_bits_per_pass; if(begin_bit >= end_bit) { break; } - if ROCPRIM_IF_CONSTEXPR(warp_striped) + if constexpr(warp_striped) { exchange_keys_warp_striped(storage, keys, ranks); exchange_values_warp_striped(storage, values, ranks); @@ -1177,7 +1209,7 @@ class block_radix_sort ::rocprim::syncthreads(); } - if ROCPRIM_IF_CONSTEXPR(ToStriped) + if constexpr(ToStriped) { exchange_to_striped_keys(storage, keys, ranks); exchange_to_striped_values(storage, values, ranks); @@ -1293,6 +1325,100 @@ class block_radix_sort } }; +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template +class block_radix_sort +{ + using block_radix_sort_wave32 = block_radix_sort; + using block_radix_sort_wave64 = block_radix_sort; + + using dispatch = detail::dispatch_wave_size; + +public: + using storage_type = typename dispatch::storage_type; + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto sort(Args&&... args) + { + dispatch{}([](auto impl, auto&&... args) { impl.sort(args...); }, args...); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto sort_desc(Args&&... args) + { + dispatch{}([](auto impl, auto&&... args) { impl.sort_desc(args...); }, args...); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto sort_to_striped(Args&&... args) + { + dispatch{}([](auto impl, auto&&... args) { impl.sort_to_striped(args...); }, args...); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto sort_desc_to_striped(Args&&... args) + { + dispatch{}([](auto impl, auto&&... args) { impl.sort_desc_to_striped(args...); }, args...); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto sort_warp_striped_to_striped(Args&&... args) + { + dispatch{}([](auto impl, auto&&... args) { impl.sort_warp_striped_to_striped(args...); }, + args...); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto sort_desc_warp_striped_to_striped(Args&&... args) + { + dispatch{}([](auto impl, auto&&... args) + { impl.sort_desc_warp_striped_to_striped(args...); }, + args...); + } +}; +#endif // DOXYGEN_SHOULD_SKIP_THIS + END_ROCPRIM_NAMESPACE /// @} diff --git a/rocprim/include/rocprim/block/block_reduce.hpp b/rocprim/include/rocprim/block/block_reduce.hpp index c380dd4b8..f3451ee1f 100644 --- a/rocprim/include/rocprim/block/block_reduce.hpp +++ b/rocprim/include/rocprim/block/block_reduce.hpp @@ -62,22 +62,35 @@ struct select_block_reduce_impl; template<> struct select_block_reduce_impl { - template - using type = block_reduce_warp_reduce; + template + using type = block_reduce_warp_reduce; }; template<> struct select_block_reduce_impl { - template - using type = block_reduce_raking_reduce; + template + using type = block_reduce_raking_reduce; }; template<> struct select_block_reduce_impl { - template - using type = block_reduce_raking_reduce; + template + using type + = block_reduce_raking_reduce; }; @@ -129,19 +142,21 @@ struct select_block_reduce_impl +template class block_reduce #ifndef DOXYGEN_SHOULD_SKIP_THIS - : private detail::select_block_reduce_impl::template type + : private detail::select_block_reduce_impl< + Algorithm>::template type #endif { - using base_type = typename detail::select_block_reduce_impl::template type; + using base_type = typename detail::select_block_reduce_impl< + Algorithm>::template type; + public: /// \brief Struct used to allocate a temporary memory that is required for thread /// communication during operations provided by related parallel primitive. @@ -406,6 +421,46 @@ class block_reduce } }; +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template +class block_reduce +{ +private: + using block_reduce_wave32 = block_reduce; + using block_reduce_wave64 = block_reduce; + using dispatch = detail::dispatch_wave_size; + +public: + using storage_type = typename dispatch::storage_type; + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto reduce(Args&&... args) + { + dispatch{}([](auto impl, auto&&... args) { impl.reduce(args...); }, args...); + } +}; +#endif // DOXYGEN_SHOULD_SKIP_THIS + END_ROCPRIM_NAMESPACE /// @} diff --git a/rocprim/include/rocprim/block/block_scan.hpp b/rocprim/include/rocprim/block/block_scan.hpp index c03e7fbca..aa4592c42 100644 --- a/rocprim/include/rocprim/block/block_scan.hpp +++ b/rocprim/include/rocprim/block/block_scan.hpp @@ -60,20 +60,29 @@ struct select_block_scan_impl; template<> struct select_block_scan_impl { - template - using type = block_scan_warp_scan; + template + using type = block_scan_warp_scan; }; template<> struct select_block_scan_impl { - template + template // When BlockSize is less than hardware warp size block_scan_warp_scan performs better than // block_scan_reduce_then_scan by specializing for warps using type = typename std::conditional< - (BlockSizeX * BlockSizeY * BlockSizeZ <= ::rocprim::arch::wavefront::min_size()), - block_scan_warp_scan, - block_scan_reduce_then_scan>::type; + (BlockSizeX * BlockSizeY * BlockSizeZ + <= (arch::wavefront::size_from_target())), + block_scan_warp_scan, + block_scan_reduce_then_scan>::type; }; } // end namespace detail @@ -124,19 +133,21 @@ struct select_block_scan_impl /// } /// \endcode /// \endparblock -template< - class T, - unsigned int BlockSizeX, - block_scan_algorithm Algorithm = block_scan_algorithm::default_algorithm, - unsigned int BlockSizeY = 1, - unsigned int BlockSizeZ = 1 -> +template class block_scan #ifndef DOXYGEN_SHOULD_SKIP_THIS - : private detail::select_block_scan_impl::template type + : private detail::select_block_scan_impl< + Algorithm>::template type #endif { - using base_type = typename detail::select_block_scan_impl::template type; + using base_type = typename detail::select_block_scan_impl< + Algorithm>::template type; + public: /// \brief Struct used to allocate a temporary memory that is required for thread /// communication during operations provided by related parallel primitive. @@ -201,7 +212,7 @@ class block_scan storage_type& storage, BinaryFunction scan_op = BinaryFunction()) { - base_type::inclusive_scan(input, output, storage, scan_op); + base_type{}.inclusive_scan(input, output, storage, scan_op); } /// \overload @@ -225,7 +236,7 @@ class block_scan T& output, BinaryFunction scan_op = BinaryFunction()) { - base_type::inclusive_scan(input, output, scan_op); + base_type{}.inclusive_scan(input, output, scan_op); } /// \brief Performs inclusive scan and reduction across threads in a block. @@ -286,7 +297,7 @@ class block_scan storage_type& storage, BinaryFunction scan_op = BinaryFunction()) { - base_type::inclusive_scan(input, output, reduction, storage, scan_op); + base_type{}.inclusive_scan(input, output, reduction, storage, scan_op); } /// \overload @@ -312,7 +323,7 @@ class block_scan T& reduction, BinaryFunction scan_op = BinaryFunction()) { - base_type::inclusive_scan(input, output, reduction, scan_op); + base_type{}.inclusive_scan(input, output, reduction, scan_op); } /// \brief Performs inclusive scan across threads in a block, and uses @@ -402,7 +413,7 @@ class block_scan PrefixCallback& prefix_callback_op, BinaryFunction scan_op) { - base_type::inclusive_scan(input, output, storage, prefix_callback_op, scan_op); + base_type{}.inclusive_scan(input, output, storage, prefix_callback_op, scan_op); } /// \brief Performs inclusive scan across threads in a block. @@ -464,11 +475,11 @@ class block_scan { if(ItemsPerThread == 1) { - base_type::inclusive_scan(input[0], output[0], storage, scan_op); + base_type{}.inclusive_scan(input[0], output[0], storage, scan_op); } else { - base_type::inclusive_scan(input, output, storage, scan_op); + base_type{}.inclusive_scan(input, output, storage, scan_op); } } @@ -499,11 +510,11 @@ class block_scan { if(ItemsPerThread == 1) { - base_type::inclusive_scan(input[0], output[0], scan_op); + base_type{}.inclusive_scan(input[0], output[0], scan_op); } else { - base_type::inclusive_scan(input, output, scan_op); + base_type{}.inclusive_scan(input, output, scan_op); } } @@ -565,7 +576,7 @@ class block_scan storage_type& storage, BinaryFunction scan_op = BinaryFunction()) { - base_type::inclusive_scan(input, init, output, storage, scan_op); + base_type{}.inclusive_scan(input, init, output, storage, scan_op); } /// \overload @@ -595,7 +606,7 @@ class block_scan T (&output)[ItemsPerThread], BinaryFunction scan_op = BinaryFunction()) { - base_type::inclusive_scan(input, init, output, scan_op); + base_type{}.inclusive_scan(input, init, output, scan_op); } /// \brief Performs inclusive scan and reduction across threads in a block. @@ -661,11 +672,11 @@ class block_scan { if(ItemsPerThread == 1) { - base_type::inclusive_scan(input[0], output[0], reduction, storage, scan_op); + base_type{}.inclusive_scan(input[0], output[0], reduction, storage, scan_op); } else { - base_type::inclusive_scan(input, output, reduction, storage, scan_op); + base_type{}.inclusive_scan(input, output, reduction, storage, scan_op); } } @@ -698,11 +709,11 @@ class block_scan { if(ItemsPerThread == 1) { - base_type::inclusive_scan(input[0], output[0], reduction, scan_op); + base_type{}.inclusive_scan(input[0], output[0], reduction, scan_op); } else { - base_type::inclusive_scan(input, output, reduction, scan_op); + base_type{}.inclusive_scan(input, output, reduction, scan_op); } } @@ -715,7 +726,8 @@ class block_scan /// \param [in] input reference to an array containing thread input values. /// \param [in] init initial value to seed the inclusive scan. /// \param [out] output reference to a thread output array. May be aliased with \p input. - /// \param [out] reduction result of reducing of all \p input values in a block. + /// \param [out] reduction result of reducing of all \p input values in a block. This does + /// not include \p init. /// \param [in] storage reference to a temporary storage object of type storage_type. /// \param [in] scan_op binary operation function object that will be used for scan. /// The signature of the function should be equivalent to the following: @@ -769,7 +781,7 @@ class block_scan storage_type& storage, BinaryFunction scan_op = BinaryFunction()) { - base_type::inclusive_scan(input, init, output, reduction, storage, scan_op); + base_type{}.inclusive_scan(input, init, output, reduction, storage, scan_op); } /// \overload @@ -785,7 +797,8 @@ class block_scan /// \param [in] input reference to an array containing thread input values. /// \param [in] init initial value to seed the inclusive scan. /// \param [out] output reference to a thread output array. May be aliased with \p input. - /// \param [out] reduction result of reducing of all \p input values in a block. + /// \param [out] reduction result of reducing of all \p input values in a block. This does + /// not include \p init. /// \param [in] scan_op binary operation function object that will be used for scan. /// The signature of the function should be equivalent to the following: /// T f(const T &a, const T &b);. The signature does not need to have @@ -801,7 +814,7 @@ class block_scan T& reduction, BinaryFunction scan_op = BinaryFunction()) { - base_type::inclusive_scan(input, init, output, reduction, scan_op); + base_type{}.inclusive_scan(input, init, output, reduction, scan_op); } /// \brief Performs inclusive scan across threads in a block, and uses @@ -895,11 +908,11 @@ class block_scan { if(ItemsPerThread == 1) { - base_type::inclusive_scan(input[0], output[0], storage, prefix_callback_op, scan_op); + base_type{}.inclusive_scan(input[0], output[0], storage, prefix_callback_op, scan_op); } else { - base_type::inclusive_scan(input, output, storage, prefix_callback_op, scan_op); + base_type{}.inclusive_scan(input, output, storage, prefix_callback_op, scan_op); } } @@ -961,7 +974,7 @@ class block_scan storage_type& storage, BinaryFunction scan_op = BinaryFunction()) { - base_type::exclusive_scan(input, output, init, storage, scan_op); + base_type{}.exclusive_scan(input, output, init, storage, scan_op); } /// \overload @@ -988,7 +1001,7 @@ class block_scan T init, BinaryFunction scan_op = BinaryFunction()) { - base_type::exclusive_scan(input, output, init, scan_op); + base_type{}.exclusive_scan(input, output, init, scan_op); } /// \brief Performs exclusive scan and reduction across threads in a block. @@ -1054,7 +1067,7 @@ class block_scan storage_type& storage, BinaryFunction scan_op = BinaryFunction()) { - base_type::exclusive_scan(input, output, init, reduction, storage, scan_op); + base_type{}.exclusive_scan(input, output, init, reduction, storage, scan_op); } /// \overload @@ -1083,7 +1096,7 @@ class block_scan T& reduction, BinaryFunction scan_op = BinaryFunction()) { - base_type::exclusive_scan(input, output, init, reduction, scan_op); + base_type{}.exclusive_scan(input, output, init, reduction, scan_op); } /// \brief Performs exclusive scan across threads in a block, and uses @@ -1173,7 +1186,7 @@ class block_scan PrefixCallback& prefix_callback_op, BinaryFunction scan_op) { - base_type::exclusive_scan(input, output, storage, prefix_callback_op, scan_op); + base_type{}.exclusive_scan(input, output, storage, prefix_callback_op, scan_op); } /// \brief Performs exclusive scan across threads in a block. @@ -1240,11 +1253,11 @@ class block_scan { if(ItemsPerThread == 1) { - base_type::exclusive_scan(input[0], output[0], init, storage, scan_op); + base_type{}.exclusive_scan(input[0], output[0], init, storage, scan_op); } else { - base_type::exclusive_scan(input, output, init, storage, scan_op); + base_type{}.exclusive_scan(input, output, init, storage, scan_op); } } @@ -1278,11 +1291,11 @@ class block_scan { if(ItemsPerThread == 1) { - base_type::exclusive_scan(input[0], output[0], init, scan_op); + base_type{}.exclusive_scan(input[0], output[0], init, scan_op); } else { - base_type::exclusive_scan(input, output, init, scan_op); + base_type{}.exclusive_scan(input, output, init, scan_op); } } @@ -1355,11 +1368,11 @@ class block_scan { if(ItemsPerThread == 1) { - base_type::exclusive_scan(input[0], output[0], init, reduction, storage, scan_op); + base_type{}.exclusive_scan(input[0], output[0], init, reduction, storage, scan_op); } else { - base_type::exclusive_scan(input, output, init, reduction, storage, scan_op); + base_type{}.exclusive_scan(input, output, init, reduction, storage, scan_op); } } @@ -1395,11 +1408,11 @@ class block_scan { if(ItemsPerThread == 1) { - base_type::exclusive_scan(input[0], output[0], init, reduction, scan_op); + base_type{}.exclusive_scan(input[0], output[0], init, reduction, scan_op); } else { - base_type::exclusive_scan(input, output, init, reduction, scan_op); + base_type{}.exclusive_scan(input, output, init, reduction, scan_op); } } @@ -1494,15 +1507,58 @@ class block_scan { if(ItemsPerThread == 1) { - base_type::exclusive_scan(input[0], output[0], storage, prefix_callback_op, scan_op); + base_type{}.exclusive_scan(input[0], output[0], storage, prefix_callback_op, scan_op); } else { - base_type::exclusive_scan(input, output, storage, prefix_callback_op, scan_op); + base_type{}.exclusive_scan(input, output, storage, prefix_callback_op, scan_op); } } }; +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template +class block_scan +{ +private: + using block_scan_wave32 = block_scan; + using block_scan_wave64 = block_scan; + + using dispatch = detail::dispatch_wave_size; + +public: + using storage_type = typename dispatch::storage_type; + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto inclusive_scan(Args&&... args) + { + dispatch{}([](auto impl, auto&&... args) { impl.inclusive_scan(args...); }, args...); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto exclusive_scan(Args&&... args) + { + dispatch{}([](auto impl, auto&&... args) { impl.exclusive_scan(args...); }, args...); + } +}; +#endif // DOXYGEN_SHOULD_SKIP_THIS + END_ROCPRIM_NAMESPACE /// @} diff --git a/rocprim/include/rocprim/block/block_store_func.hpp b/rocprim/include/rocprim/block/block_store_func.hpp index 8f3f1b4cd..aad00ab5f 100644 --- a/rocprim/include/rocprim/block/block_store_func.hpp +++ b/rocprim/include/rocprim/block/block_store_func.hpp @@ -29,6 +29,8 @@ #include "../types.hpp" #include "rocprim/intrinsics/arch.hpp" +#include "../thread/thread_store.hpp" + /// \addtogroup blockmodule /// @{ @@ -276,18 +278,18 @@ void block_store_direct_striped(unsigned int flat_id, /// into a blocked arrangement on continuous memory. /// /// \ingroup blockmodule_warp_store_functions -/// The warp-striped arrangement is assumed to be (\p WarpSize * \p ItemsPerThread) items +/// The warp-striped arrangement is assumed to be (\p VirtualWaveSize * \p ItemsPerThread) items /// across a thread block. Each thread uses a \p flat_id to store a range of /// \p ItemsPerThread \p items to the thread block. /// -/// * The number of threads in the block must be a multiple of \p WarpSize. -/// * The default \p WarpSize is a hardware warpsize and is an optimal value. -/// * \p WarpSize must be a power of two and equal or less than the size of +/// * The number of threads in the block must be a multiple of \p VirtualWaveSize. +/// * The default \p VirtualWaveSize is a hardware warpsize and is an optimal value. +/// * \p VirtualWaveSize must be a power of two and equal or less than the size of /// hardware warp. -/// * Using \p WarpSize smaller than hardware warpsize could result in lower +/// * Using \p VirtualWaveSize smaller than hardware warpsize could result in lower /// performance. /// -/// \tparam WarpSize [optional] the number of threads in a warp +/// \tparam VirtualWaveSize [optional] the number of threads in a warp /// \tparam OutputIterator [inferred] an iterator type for input (can be a simple /// pointer /// \tparam T [inferred] the data type @@ -297,7 +299,7 @@ void block_store_direct_striped(unsigned int flat_id, /// \param flat_id a local flat 1D thread id in a block (tile) for the calling thread /// \param block_output the input iterator from the thread block to store to /// \param items array that data is stored to thread block -template @@ -310,18 +312,19 @@ void block_store_direct_warp_striped(unsigned int flat_id, "The type T must be such that an object of type OutputIterator " "can be dereferenced and assigned a value of type T."); - static_assert(detail::is_power_of_two(WarpSize) && WarpSize <= arch::wavefront::max_size(), - "WarpSize must be a power of two and equal or less" + static_assert(detail::is_power_of_two(VirtualWaveSize) + && VirtualWaveSize <= arch::wavefront::max_size(), + "VirtualWaveSize must be a power of two and equal or less" "than the size of hardware warp."); - unsigned int thread_id = detail::logical_lane_id(); - unsigned int warp_id = flat_id / WarpSize; - unsigned int warp_offset = warp_id * WarpSize * ItemsPerThread; + unsigned int thread_id = detail::logical_lane_id(); + unsigned int warp_id = flat_id / VirtualWaveSize; + unsigned int warp_offset = warp_id * VirtualWaveSize * ItemsPerThread; OutputIterator thread_iter = block_output + thread_id + warp_offset; ROCPRIM_UNROLL for (unsigned int item = 0; item < ItemsPerThread; item++) { - thread_iter[item * WarpSize] = items[item]; + thread_iter[item * VirtualWaveSize] = items[item]; } } @@ -329,18 +332,18 @@ void block_store_direct_warp_striped(unsigned int flat_id, /// into a blocked arrangement on continuous memory, which is guarded by range \p valid. /// /// \ingroup blockmodule_warp_store_functions -/// The warp-striped arrangement is assumed to be (\p WarpSize * \p ItemsPerThread) items +/// The warp-striped arrangement is assumed to be (\p VirtualWaveSize * \p ItemsPerThread) items /// across a thread block. Each thread uses a \p flat_id to store a range of /// \p ItemsPerThread \p items to the thread block. /// -/// * The number of threads in the block must be a multiple of \p WarpSize. -/// * The default \p WarpSize is a hardware warpsize and is an optimal value. -/// * \p WarpSize must be a power of two and equal or less than the size of +/// * The number of threads in the block must be a multiple of \p VirtualWaveSize. +/// * The default \p VirtualWaveSize is a hardware warpsize and is an optimal value. +/// * \p VirtualWaveSize must be a power of two and equal or less than the size of /// hardware warp. -/// * Using \p WarpSize smaller than hardware warpsize could result in lower +/// * Using \p VirtualWaveSize smaller than hardware warpsize could result in lower /// performance. /// -/// \tparam WarpSize [optional] the number of threads in a warp +/// \tparam VirtualWaveSize [optional] the number of threads in a warp /// \tparam OutputIterator [inferred] an iterator type for input (can be a simple /// pointer /// \tparam T [inferred] the data type @@ -351,7 +354,7 @@ void block_store_direct_warp_striped(unsigned int flat_id, /// \param block_output the input iterator from the thread block to store to /// \param items array that data is stored to thread block /// \param valid maximum range of valid numbers to store -template @@ -365,20 +368,21 @@ void block_store_direct_warp_striped(unsigned int flat_id, "The type T must be such that an object of type OutputIterator " "can be dereferenced and assigned a value of type T."); - static_assert(detail::is_power_of_two(WarpSize) && WarpSize <= arch::wavefront::max_size(), - "WarpSize must be a power of two and equal or less" + static_assert(detail::is_power_of_two(VirtualWaveSize) + && VirtualWaveSize <= arch::wavefront::max_size(), + "VirtualWaveSize must be a power of two and equal or less" "than the size of hardware warp."); - assert(WarpSize <= arch::wavefront::size()); + assert(VirtualWaveSize <= arch::wavefront::size()); - unsigned int thread_id = detail::logical_lane_id(); - unsigned int warp_id = flat_id / WarpSize; - unsigned int warp_offset = warp_id * WarpSize * ItemsPerThread; + unsigned int thread_id = detail::logical_lane_id(); + unsigned int warp_id = flat_id / VirtualWaveSize; + unsigned int warp_offset = warp_id * VirtualWaveSize * ItemsPerThread; OutputIterator thread_iter = block_output + thread_id + warp_offset; ROCPRIM_UNROLL for (unsigned int item = 0; item < ItemsPerThread; item++) { - unsigned int offset = item * WarpSize; + unsigned int offset = item * VirtualWaveSize; if (warp_offset + thread_id + offset < valid) { thread_iter[offset] = items[item]; @@ -386,6 +390,47 @@ void block_store_direct_warp_striped(unsigned int flat_id, } } +template +ROCPRIM_DEVICE ROCPRIM_INLINE +auto block_store_direct_blocked_cast(unsigned int flat_id, + T* block_output, + U (&items)[ItemsPerThread]) -> + typename std::enable_if::value + && (ItemsPerThread * sizeof(T)) % sizeof(V) == 0>::type +{ + static_assert(std::is_convertible::value, + "The type U must be such that it can be implicitly converted to T."); + + constexpr unsigned int vectors_per_thread = (sizeof(T) * ItemsPerThread) / sizeof(V); + + V* vector_ptr = ::rocprim::detail::bit_cast(block_output) + flat_id * vectors_per_thread; + + ROCPRIM_UNROLL + for(unsigned int item = 0; item < vectors_per_thread; item++) + { + vector_ptr[item] = *(reinterpret_cast(items) + item); + } +} + +template +ROCPRIM_DEVICE ROCPRIM_INLINE +auto block_store_direct_blocked_cast(unsigned int flat_id, + T* block_output, + U (&items)[ItemsPerThread]) -> + typename std::enable_if::value + || (ItemsPerThread * sizeof(T)) % sizeof(V) != 0>::type +{ + block_store_direct_blocked(flat_id, block_output, items); +} + END_ROCPRIM_NAMESPACE /// @} diff --git a/rocprim/include/rocprim/block/detail/block_adjacent_difference_impl.hpp b/rocprim/include/rocprim/block/detail/block_adjacent_difference_impl.hpp index 63fdaaf07..2d2c22b2f 100644 --- a/rocprim/include/rocprim/block/detail/block_adjacent_difference_impl.hpp +++ b/rocprim/include/rocprim/block/detail/block_adjacent_difference_impl.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -40,55 +40,58 @@ namespace detail // index // block_discontinuity and block_adjacent difference only differ in their implementations by the // order the operators parameters are passed, so this method deals with this as well -template -ROCPRIM_DEVICE ROCPRIM_INLINE auto apply(BinaryFunction op, - const T& a, - const T& b, - unsigned int index, - bool_constant /*as_flags*/, - bool_constant /*reversed*/) -> decltype(op(b, a, index)) +template +ROCPRIM_DEVICE ROCPRIM_INLINE +auto apply(BinaryFunction op, + const T& a, + const T& b, + unsigned int index, + bool_constant /*as_flags*/, + bool_constant /*reversed*/) -> decltype(op(b, a, index)) { return op(a, b, index); } -template -ROCPRIM_DEVICE ROCPRIM_INLINE auto apply(BinaryFunction op, - const T& a, - const T& b, - unsigned int index, - bool_constant /*as_flags*/, - bool_constant /*reversed*/) - -> decltype(op(b, a, index)) +template +ROCPRIM_DEVICE ROCPRIM_INLINE +auto apply(BinaryFunction op, + const T& a, + const T& b, + unsigned int index, + bool_constant /*as_flags*/, + bool_constant /*reversed*/) -> decltype(op(b, a, index)) { return op(b, a, index); } -template -ROCPRIM_DEVICE ROCPRIM_INLINE auto apply(BinaryFunction op, - const T& a, - const T& b, - unsigned int, - bool_constant /*as_flags*/, - bool_constant /*reversed*/) -> decltype(op(b, a)) +template +ROCPRIM_DEVICE ROCPRIM_INLINE +auto apply(BinaryFunction op, + const T& a, + const T& b, + unsigned int, + bool_constant /*as_flags*/, + bool_constant /*reversed*/) -> decltype(op(b, a)) { return op(a, b); } -template -ROCPRIM_DEVICE ROCPRIM_INLINE auto apply(BinaryFunction op, - const T& a, - const T& b, - unsigned int, - bool_constant /*as_flags*/, - bool_constant /*reversed*/) -> decltype(op(b, a)) +template +ROCPRIM_DEVICE ROCPRIM_INLINE +auto apply(BinaryFunction op, + const T& a, + const T& b, + unsigned int, + bool_constant /*as_flags*/, + bool_constant /*reversed*/) -> decltype(op(b, a)) { return op(b, a); } -template +template class block_adjacent_difference_impl { public: @@ -98,20 +101,21 @@ class block_adjacent_difference_impl T items[BlockSize]; }; - template - ROCPRIM_DEVICE void apply_left(const T (&input)[ItemsPerThread], - Output (&output)[ItemsPerThread], - BinaryFunction op, - const T tile_predecessor_item, - storage_type& storage) + template + ROCPRIM_DEVICE + void apply_left(const T (&input)[ItemsPerThread], + Output (&output)[ItemsPerThread], + BinaryFunction op, + const T tile_predecessor_item, + storage_type& storage) { - static constexpr auto as_flags = bool_constant {}; - static constexpr auto reversed = bool_constant {}; + static constexpr auto as_flags = bool_constant{}; + static constexpr auto reversed = bool_constant{}; const unsigned int flat_id = ::rocprim::flat_block_thread_id(); @@ -122,25 +126,35 @@ class block_adjacent_difference_impl ROCPRIM_UNROLL for(unsigned int i = ItemsPerThread - 1; i > 0; --i) { - output[i] = detail::apply( - op, input[i - 1], input[i], flat_id * ItemsPerThread + i, as_flags, reversed); + output[i] = detail::apply(op, + input[i - 1], + input[i], + flat_id * ItemsPerThread + i, + as_flags, + reversed); } ::rocprim::syncthreads(); - if ROCPRIM_IF_CONSTEXPR (WithTilePredecessor) + if constexpr(WithTilePredecessor) { T predecessor_item = tile_predecessor_item; - if(flat_id != 0) { + if(flat_id != 0) + { predecessor_item = storage.items[flat_id - 1]; } - output[0] = detail::apply( - op, predecessor_item, input[0], flat_id * ItemsPerThread, as_flags, reversed); + output[0] = detail::apply(op, + predecessor_item, + input[0], + flat_id * ItemsPerThread, + as_flags, + reversed); } else { output[0] = get_default_item(input, 0, as_flags); - if(flat_id != 0) { + if(flat_id != 0) + { output[0] = detail::apply(op, storage.items[flat_id - 1], input[0], @@ -151,21 +165,22 @@ class block_adjacent_difference_impl } } - template - ROCPRIM_DEVICE void apply_left_partial(const T (&input)[ItemsPerThread], - Output (&output)[ItemsPerThread], - BinaryFunction op, - const T tile_predecessor_item, - const unsigned int valid_items, - storage_type& storage) + template + ROCPRIM_DEVICE + void apply_left_partial(const T (&input)[ItemsPerThread], + Output (&output)[ItemsPerThread], + BinaryFunction op, + const T tile_predecessor_item, + const unsigned int valid_items, + storage_type& storage) { - static constexpr auto as_flags = bool_constant {}; - static constexpr auto reversed = bool_constant {}; + static constexpr auto as_flags = bool_constant{}; + static constexpr auto reversed = bool_constant{}; const unsigned int flat_id = ::rocprim::flat_block_thread_id(); @@ -177,8 +192,9 @@ class block_adjacent_difference_impl for(unsigned int i = ItemsPerThread - 1; i > 0; --i) { const unsigned int index = flat_id * ItemsPerThread + i; - output[i] = get_default_item(input, i, as_flags); - if(index < valid_items) { + output[i] = get_default_item(input, i, as_flags); + if(index < valid_items) + { output[i] = detail::apply(op, input[i - 1], input[i], index, as_flags, reversed); } } @@ -186,10 +202,11 @@ class block_adjacent_difference_impl const unsigned int index = flat_id * ItemsPerThread; - if ROCPRIM_IF_CONSTEXPR (WithTilePredecessor) + if constexpr(WithTilePredecessor) { T predecessor_item = tile_predecessor_item; - if(flat_id != 0) { + if(flat_id != 0) + { predecessor_item = storage.items[flat_id - 1]; } @@ -215,20 +232,21 @@ class block_adjacent_difference_impl } } - template - ROCPRIM_DEVICE void apply_right(const T (&input)[ItemsPerThread], - Output (&output)[ItemsPerThread], - BinaryFunction op, - const T tile_successor_item, - storage_type& storage) + template + ROCPRIM_DEVICE + void apply_right(const T (&input)[ItemsPerThread], + Output (&output)[ItemsPerThread], + BinaryFunction op, + const T tile_successor_item, + storage_type& storage) { - static constexpr auto as_flags = bool_constant {}; - static constexpr auto reversed = bool_constant {}; + static constexpr auto as_flags = bool_constant{}; + static constexpr auto reversed = bool_constant{}; const unsigned int flat_id = ::rocprim::flat_block_thread_id(); @@ -239,15 +257,20 @@ class block_adjacent_difference_impl ROCPRIM_UNROLL for(unsigned int i = 0; i < ItemsPerThread - 1; ++i) { - output[i] = detail::apply( - op, input[i], input[i + 1], flat_id * ItemsPerThread + i + 1, as_flags, reversed); + output[i] = detail::apply(op, + input[i], + input[i + 1], + flat_id * ItemsPerThread + i + 1, + as_flags, + reversed); } ::rocprim::syncthreads(); - if ROCPRIM_IF_CONSTEXPR (WithTileSuccessor) + if constexpr(WithTileSuccessor) { T successor_item = tile_successor_item; - if(flat_id != BlockSize - 1) { + if(flat_id != BlockSize - 1) + { successor_item = storage.items[flat_id + 1]; } @@ -261,7 +284,8 @@ class block_adjacent_difference_impl else { output[ItemsPerThread - 1] = get_default_item(input, ItemsPerThread - 1, as_flags); - if(flat_id != BlockSize - 1) { + if(flat_id != BlockSize - 1) + { output[ItemsPerThread - 1] = detail::apply(op, input[ItemsPerThread - 1], @@ -272,19 +296,20 @@ class block_adjacent_difference_impl } } } - template - ROCPRIM_DEVICE void apply_right_partial(const T (&input)[ItemsPerThread], - Output (&output)[ItemsPerThread], - BinaryFunction op, - const unsigned int valid_items, - storage_type& storage) + template + ROCPRIM_DEVICE + void apply_right_partial(const T (&input)[ItemsPerThread], + Output (&output)[ItemsPerThread], + BinaryFunction op, + const unsigned int valid_items, + storage_type& storage) { - static constexpr auto as_flags = bool_constant {}; - static constexpr auto reversed = bool_constant {}; + static constexpr auto as_flags = bool_constant{}; + static constexpr auto reversed = bool_constant{}; const unsigned int flat_id = ::rocprim::flat_block_thread_id(); @@ -296,7 +321,7 @@ class block_adjacent_difference_impl for(unsigned int i = 0; i < ItemsPerThread - 1; ++i) { const unsigned int index = flat_id * ItemsPerThread + i + 1; - output[i] = get_default_item(input, i, as_flags); + output[i] = get_default_item(input, i, as_flags); if(index < valid_items) { output[i] = detail::apply(op, input[i], input[i + 1], index, as_flags, reversed); @@ -319,18 +344,20 @@ class block_adjacent_difference_impl } private: - template - ROCPRIM_DEVICE int get_default_item(const T (&)[ItemsPerThread], - unsigned int /*index*/, - bool_constant /*as_flags*/) + template + ROCPRIM_DEVICE + int get_default_item(const T (&)[ItemsPerThread], + unsigned int /*index*/, + bool_constant /*as_flags*/) { return 1; } - template - ROCPRIM_DEVICE T get_default_item(const T (&input)[ItemsPerThread], - const unsigned int index, - bool_constant /*as_flags*/) + template + ROCPRIM_DEVICE + T get_default_item(const T (&input)[ItemsPerThread], + const unsigned int index, + bool_constant /*as_flags*/) { return input[index]; } diff --git a/rocprim/include/rocprim/block/detail/block_radix_rank_basic.hpp b/rocprim/include/rocprim/block/detail/block_radix_rank_basic.hpp index bef343442..c2243b4cc 100644 --- a/rocprim/include/rocprim/block/detail/block_radix_rank_basic.hpp +++ b/rocprim/include/rocprim/block/detail/block_radix_rank_basic.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -24,7 +24,7 @@ #include "../../config.hpp" #include "../../detail/various.hpp" #include "../../functional.hpp" -#include "../../thread/radix_key_codec.hpp" +#include "../../type_traits.hpp" #include "../block_scan.hpp" @@ -33,11 +33,12 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template +template class block_radix_rank { using digit_counter_type = unsigned short; @@ -47,7 +48,8 @@ class block_radix_rank BlockSizeX, ::rocprim::block_scan_algorithm::using_warp_scan, BlockSizeY, - BlockSizeZ>; + BlockSizeZ, + TargetWaveSize>; static constexpr unsigned int block_size = BlockSizeX * BlockSizeY * BlockSizeZ; static constexpr unsigned int radix_digits = 1 << RadixBits; @@ -71,9 +73,10 @@ class block_radix_rank typename block_scan_type::storage_type block_scan; }; - ROCPRIM_DEVICE ROCPRIM_INLINE digit_counter_type& get_digit_counter(const unsigned int digit, - const unsigned int thread, - storage_type_& storage) + ROCPRIM_DEVICE ROCPRIM_INLINE + digit_counter_type& get_digit_counter(const unsigned int digit, + const unsigned int thread, + storage_type_& storage) { const unsigned int column_counter = digit % column_size; const unsigned int sub_counter = digit / column_size; @@ -82,8 +85,8 @@ class block_radix_rank return storage.digit_counters[counter]; }; - ROCPRIM_DEVICE ROCPRIM_INLINE void reset_counters(const unsigned int flat_id, - storage_type_& storage) + ROCPRIM_DEVICE ROCPRIM_INLINE + void reset_counters(const unsigned int flat_id, storage_type_& storage) { for(unsigned int i = flat_id; i < block_size * column_size; i += block_size) { @@ -91,8 +94,8 @@ class block_radix_rank } } - ROCPRIM_DEVICE ROCPRIM_INLINE void - scan_block_counters(storage_type_& storage, packed_counter_type* const packed_counters) + ROCPRIM_DEVICE ROCPRIM_INLINE + void scan_block_counters(storage_type_& storage, packed_counter_type* const packed_counters) { packed_counter_type block_reduction = 0; ROCPRIM_UNROLL @@ -124,13 +127,13 @@ class block_radix_rank } } - ROCPRIM_DEVICE ROCPRIM_INLINE void scan_counters(const unsigned int flat_id, - storage_type_& storage) + ROCPRIM_DEVICE ROCPRIM_INLINE + void scan_counters(const unsigned int flat_id, storage_type_& storage) { packed_counter_type* const shared_counters = &storage.packed_counters[flat_id * column_size]; - if ROCPRIM_IF_CONSTEXPR(MemoizeOuterScan) + if constexpr(MemoizeOuterScan) { packed_counter_type local_counters[column_size]; ROCPRIM_UNROLL @@ -154,10 +157,11 @@ class block_radix_rank } template - ROCPRIM_DEVICE void rank_keys_impl(const Key (&keys)[ItemsPerThread], - unsigned int (&ranks)[ItemsPerThread], - storage_type_& storage, - DigitExtractor digit_extractor) + ROCPRIM_DEVICE + void rank_keys_impl(const Key (&keys)[ItemsPerThread], + unsigned int (&ranks)[ItemsPerThread], + storage_type_& storage, + DigitExtractor digit_extractor) { static_assert(block_size * ItemsPerThread < 1u << 16, "The maximum amout of items that block_radix_rank can rank is 2**16."); @@ -191,13 +195,15 @@ class block_radix_rank } template - ROCPRIM_DEVICE void rank_keys_impl(const Key (&keys)[ItemsPerThread], - unsigned int (&ranks)[ItemsPerThread], - storage_type_& storage, - const unsigned int begin_bit, - const unsigned int pass_bits) + ROCPRIM_DEVICE + void rank_keys_impl(const Key (&keys)[ItemsPerThread], + unsigned int (&ranks)[ItemsPerThread], + storage_type_& storage, + const unsigned int begin_bit, + const unsigned int pass_bits) { - using key_codec = ::rocprim::radix_key_codec; + using key_codec + = decltype(::rocprim::traits::get().template radix_key_codec()); using bit_key_type = typename key_codec::bit_key_type; bit_key_type bit_keys[ItemsPerThread]; @@ -215,9 +221,10 @@ class block_radix_rank } template - ROCPRIM_DEVICE void digit_prefix_count(unsigned int (&prefix)[digits_per_thread], - unsigned int (&counts)[digits_per_thread], - storage_type_& storage) + ROCPRIM_DEVICE + void digit_prefix_count(unsigned int (&prefix)[digits_per_thread], + unsigned int (&counts)[digits_per_thread], + storage_type_& storage) { const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); @@ -245,39 +252,43 @@ class block_radix_rank ROCPRIM_DETAIL_SUPPRESS_DEPRECATION_POP template - ROCPRIM_DEVICE void rank_keys(const Key (&keys)[ItemsPerThread], - unsigned int (&ranks)[ItemsPerThread], - storage_type& storage, - unsigned int begin_bit = 0, - unsigned int pass_bits = RadixBits) + ROCPRIM_DEVICE + void rank_keys(const Key (&keys)[ItemsPerThread], + unsigned int (&ranks)[ItemsPerThread], + storage_type& storage, + unsigned int begin_bit = 0, + unsigned int pass_bits = RadixBits) { rank_keys_impl(keys, ranks, storage.get(), begin_bit, pass_bits); } template - ROCPRIM_DEVICE void rank_keys_desc(const Key (&keys)[ItemsPerThread], - unsigned int (&ranks)[ItemsPerThread], - storage_type& storage, - unsigned int begin_bit = 0, - unsigned int pass_bits = RadixBits) + ROCPRIM_DEVICE + void rank_keys_desc(const Key (&keys)[ItemsPerThread], + unsigned int (&ranks)[ItemsPerThread], + storage_type& storage, + unsigned int begin_bit = 0, + unsigned int pass_bits = RadixBits) { rank_keys_impl(keys, ranks, storage.get(), begin_bit, pass_bits); } template - ROCPRIM_DEVICE void rank_keys(const Key (&keys)[ItemsPerThread], - unsigned int (&ranks)[ItemsPerThread], - storage_type& storage, - DigitExtractor digit_extractor) + ROCPRIM_DEVICE + void rank_keys(const Key (&keys)[ItemsPerThread], + unsigned int (&ranks)[ItemsPerThread], + storage_type& storage, + DigitExtractor digit_extractor) { rank_keys_impl(keys, ranks, storage.get(), digit_extractor); } template - ROCPRIM_DEVICE void rank_keys_desc(const Key (&keys)[ItemsPerThread], - unsigned int (&ranks)[ItemsPerThread], - storage_type& storage, - DigitExtractor digit_extractor) + ROCPRIM_DEVICE + void rank_keys_desc(const Key (&keys)[ItemsPerThread], + unsigned int (&ranks)[ItemsPerThread], + storage_type& storage, + DigitExtractor digit_extractor) { rank_keys_impl(keys, ranks, @@ -290,12 +301,13 @@ class block_radix_rank } template - ROCPRIM_DEVICE void rank_keys(const Key (&keys)[ItemsPerThread], - unsigned int (&ranks)[ItemsPerThread], - storage_type& storage, - DigitExtractor digit_extractor, - unsigned int (&prefix)[digits_per_thread], - unsigned int (&counts)[digits_per_thread]) + ROCPRIM_DEVICE + void rank_keys(const Key (&keys)[ItemsPerThread], + unsigned int (&ranks)[ItemsPerThread], + storage_type& storage, + DigitExtractor digit_extractor, + unsigned int (&prefix)[digits_per_thread], + unsigned int (&counts)[digits_per_thread]) { rank_keys(keys, ranks, storage, digit_extractor); digit_prefix_count(prefix, counts, storage.get()); diff --git a/rocprim/include/rocprim/block/detail/block_radix_rank_match.hpp b/rocprim/include/rocprim/block/detail/block_radix_rank_match.hpp index 7f3b9da6c..2d7c9d8e5 100644 --- a/rocprim/include/rocprim/block/detail/block_radix_rank_match.hpp +++ b/rocprim/include/rocprim/block/detail/block_radix_rank_match.hpp @@ -26,8 +26,6 @@ #include "../../functional.hpp" #include "../../types.hpp" -#include "../../thread/radix_key_codec.hpp" - #include "../block_scan.hpp" #include "../config.hpp" #include "rocprim/intrinsics/arch.hpp" @@ -37,11 +35,12 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template +template class block_radix_rank_match { using digit_counter_type = unsigned int; @@ -50,7 +49,8 @@ class block_radix_rank_match BlockSizeX, ::rocprim::block_scan_algorithm::using_warp_scan, BlockSizeY, - BlockSizeZ>; + BlockSizeZ, + TargetWaveSize>; static constexpr unsigned int block_size = BlockSizeX * BlockSizeY * BlockSizeZ; static constexpr unsigned int radix_digits = 1 << RadixBits; @@ -58,8 +58,8 @@ class block_radix_rank_match struct unpadded_config { // min size is used because we allocate based on the number of warps - static constexpr unsigned int warps - = ::rocprim::detail::ceiling_div(block_size, arch::wavefront::min_size()); + static constexpr unsigned int warps = ::rocprim::detail::ceiling_div( + block_size, arch::wavefront::size_from_target()); }; struct padded_config @@ -204,7 +204,8 @@ class block_radix_rank_match const unsigned int begin_bit, const unsigned int pass_bits) { - using key_codec = ::rocprim::radix_key_codec; + using key_codec + = decltype(::rocprim::traits::get().template radix_key_codec()); using bit_key_type = typename key_codec::bit_key_type; bit_key_type bit_keys[ItemsPerThread]; diff --git a/rocprim/include/rocprim/block/detail/block_reduce_raking_reduce.hpp b/rocprim/include/rocprim/block/detail/block_reduce_raking_reduce.hpp index 98201c496..7face8f41 100644 --- a/rocprim/include/rocprim/block/detail/block_reduce_raking_reduce.hpp +++ b/rocprim/include/rocprim/block/detail/block_reduce_raking_reduce.hpp @@ -100,10 +100,11 @@ class fast_array sizeof(int32_t))>> #endif // DOXYGEN_SHOULD_SKIP_THIS template + unsigned int BlockSizeX, + unsigned int BlockSizeY, + unsigned int BlockSizeZ, + arch::wavefront::target TargetWaveSize, + bool CommutativeOnly = false> class block_reduce_raking_reduce { static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ; @@ -111,7 +112,7 @@ class block_reduce_raking_reduce // Warp reduce, warp_reduce_crosslane does not require shared memory (storage), but // logical warp size must be a power of two. static constexpr unsigned int warp_size_ - = detail::get_min_warp_size(BlockSize, ::rocprim::arch::wavefront::min_size()); + = detail::get_min_warp_size(BlockSize, arch::wavefront::size_from_target()); static constexpr unsigned int segment_len = ceiling_div(BlockSize, warp_size_); diff --git a/rocprim/include/rocprim/block/detail/block_reduce_warp_reduce.hpp b/rocprim/include/rocprim/block/detail/block_reduce_warp_reduce.hpp index 50b52820c..27b5c252f 100644 --- a/rocprim/include/rocprim/block/detail/block_reduce_warp_reduce.hpp +++ b/rocprim/include/rocprim/block/detail/block_reduce_warp_reduce.hpp @@ -26,8 +26,8 @@ #include "../../config.hpp" #include "../../detail/various.hpp" -#include "../../intrinsics.hpp" #include "../../functional.hpp" +#include "../../intrinsics.hpp" #include "../../warp/warp_reduce.hpp" @@ -36,24 +36,23 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template< - class T, - unsigned int BlockSizeX, - unsigned int BlockSizeY, - unsigned int BlockSizeZ -> +template class block_reduce_warp_reduce { static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ; // Select warp size static constexpr unsigned int warp_size_ - = detail::get_min_warp_size(BlockSize, ::rocprim::arch::wavefront::min_size()); + = detail::get_min_warp_size(BlockSize, arch::wavefront::size_from_target()); // Number of warps in block static constexpr unsigned int warps_no_ = (BlockSize + warp_size_ - 1) / warp_size_; // Check if we have to pass number of valid items into warp reduction primitive static constexpr bool block_size_is_warp_multiple_ = ((BlockSize % warp_size_) == 0); - static constexpr bool warps_no_is_pow_of_two_ = detail::is_power_of_two(warps_no_); + static constexpr bool warps_no_is_pow_of_two_ = detail::is_power_of_two(warps_no_); // typedef of warp_reduce primitive that will be used to perform warp-level // reduce operation on input values. @@ -62,9 +61,8 @@ class block_reduce_warp_reduce using warp_reduce_input_type = ::rocprim::detail::warp_reduce_crosslane; // typedef of warp_reduce primitive that will be used to perform reduction // of results of warp-level reduction. - using warp_reduce_output_type = ::rocprim::detail::warp_reduce_crosslane< - T, detail::next_power_of_two(warps_no_), false - >; + using warp_reduce_output_type + = ::rocprim::detail::warp_reduce_crosslane; struct storage_type_ { @@ -78,15 +76,13 @@ class block_reduce_warp_reduce template ROCPRIM_DEVICE ROCPRIM_INLINE - void reduce(T input, - T& output, - storage_type& storage, - BinaryFunction reduce_op) + void reduce(T input, T& output, storage_type& storage, BinaryFunction reduce_op) { - this->reduce_impl( - ::rocprim::flat_block_thread_id(), - input, output, storage, reduce_op - ); + this->reduce_impl(::rocprim::flat_block_thread_id(), + input, + output, + storage, + reduce_op); } template @@ -102,8 +98,8 @@ class block_reduce_warp_reduce template ROCPRIM_DEVICE ROCPRIM_INLINE void reduce(T (&input)[ItemsPerThread], - T& output, - storage_type& storage, + T& output, + storage_type& storage, BinaryFunction reduce_op) { // Reduce thread items @@ -116,12 +112,11 @@ class block_reduce_warp_reduce // Reduction of reduced values to get partials const auto flat_tid = ::rocprim::flat_block_thread_id(); - this->reduce_impl( - flat_tid, - thread_input, output, // input, output - storage, - reduce_op - ); + this->reduce_impl(flat_tid, + thread_input, + output, // input, output + storage, + reduce_op); } template @@ -136,16 +131,18 @@ class block_reduce_warp_reduce template ROCPRIM_DEVICE ROCPRIM_INLINE - void reduce(T input, - T& output, - unsigned int valid_items, - storage_type& storage, + void reduce(T input, + T& output, + unsigned int valid_items, + storage_type& storage, BinaryFunction reduce_op) { - this->reduce_impl( - ::rocprim::flat_block_thread_id(), - input, output, valid_items, storage, reduce_op - ); + this->reduce_impl(::rocprim::flat_block_thread_id(), + input, + output, + valid_items, + storage, + reduce_op); } template @@ -163,25 +160,25 @@ class block_reduce_warp_reduce template ROCPRIM_DEVICE ROCPRIM_INLINE void reduce_impl(const unsigned int flat_tid, - T input, - T& output, - storage_type& storage, - BinaryFunction reduce_op) + T input, + T& output, + storage_type& storage, + BinaryFunction reduce_op) { - const auto warp_id = ::rocprim::warp_id(flat_tid); - const auto lane_id = ::rocprim::lane_id(); + const auto warp_id = ::rocprim::warp_id(flat_tid); + const auto lane_id = ::rocprim::lane_id(); const unsigned int warp_offset = warp_id * warp_size_; - const unsigned int num_valid = - (warp_offset < BlockSize) ? BlockSize - warp_offset : 0; - storage_type_& storage_ = storage.get(); + const unsigned int num_valid = (warp_offset < BlockSize) ? BlockSize - warp_offset : 0; + storage_type_& storage_ = storage.get(); // Perform warp reduce - warp_reduce( - input, output, num_valid, reduce_op - ); + warp_reduce(input, + output, + num_valid, + reduce_op); // Final reduction across warps is only required if there is more than 1 warp - if ROCPRIM_IF_CONSTEXPR (warps_no_ > 1) + if constexpr(warps_no_ > 1) { // i-th warp will have its partial stored in storage_.warp_partials[i-1] if(lane_id == 0) @@ -205,54 +202,41 @@ class block_reduce_warp_reduce template ROCPRIM_DEVICE ROCPRIM_INLINE - auto warp_reduce(T input, - T& output, - const unsigned int valid_items, - BinaryFunction reduce_op) + auto warp_reduce(T input, T& output, const unsigned int valid_items, BinaryFunction reduce_op) -> typename std::enable_if::type { - WarpReduce().reduce( - input, output, valid_items, reduce_op - ); + WarpReduce().reduce(input, output, valid_items, reduce_op); } template ROCPRIM_DEVICE ROCPRIM_INLINE - auto warp_reduce(T input, - T& output, - const unsigned int valid_items, - BinaryFunction reduce_op) + auto warp_reduce(T input, T& output, const unsigned int valid_items, BinaryFunction reduce_op) -> typename std::enable_if::type { - (void) valid_items; - WarpReduce().reduce( - input, output, reduce_op - ); + (void)valid_items; + WarpReduce().reduce(input, output, reduce_op); } template ROCPRIM_DEVICE ROCPRIM_INLINE void reduce_impl(const unsigned int flat_tid, - T input, - T& output, + T input, + T& output, const unsigned int valid_items, - storage_type& storage, - BinaryFunction reduce_op) + storage_type& storage, + BinaryFunction reduce_op) { - const auto warp_id = ::rocprim::warp_id(flat_tid); - const auto lane_id = ::rocprim::lane_id(); + const auto warp_id = ::rocprim::warp_id(flat_tid); + const auto lane_id = ::rocprim::lane_id(); const unsigned int warp_offset = warp_id * warp_size_; - const unsigned int num_valid = - (warp_offset < valid_items) ? valid_items - warp_offset : 0; - storage_type_& storage_ = storage.get(); + const unsigned int num_valid = (warp_offset < valid_items) ? valid_items - warp_offset : 0; + storage_type_& storage_ = storage.get(); // Perform warp reduce - warp_reduce_input_type().reduce( - input, output, num_valid, reduce_op - ); + warp_reduce_input_type().reduce(input, output, num_valid, reduce_op); // Final reduction across warps is only required if there is more than 1 warp - if ROCPRIM_IF_CONSTEXPR (warps_no_ > 1) + if constexpr(warps_no_ > 1) { // i-th warp will have its partial stored in storage_.warp_partials[i-1] if(lane_id == 0) @@ -267,9 +251,7 @@ class block_reduce_warp_reduce auto warp_partial = storage_.warp_partials[lane_id]; unsigned int valid_warps_no = (valid_items + warp_size_ - 1) / warp_size_; - warp_reduce_output_type().reduce( - warp_partial, output, valid_warps_no, reduce_op - ); + warp_reduce_output_type().reduce(warp_partial, output, valid_warps_no, reduce_op); } } } diff --git a/rocprim/include/rocprim/block/detail/block_scan_reduce_then_scan.hpp b/rocprim/include/rocprim/block/detail/block_scan_reduce_then_scan.hpp index 8c93b16c3..e47ce9245 100644 --- a/rocprim/include/rocprim/block/detail/block_scan_reduce_then_scan.hpp +++ b/rocprim/include/rocprim/block/detail/block_scan_reduce_then_scan.hpp @@ -29,6 +29,9 @@ #include "../../intrinsics.hpp" #include "../../functional.hpp" +#include "../../intrinsics/thread.hpp" +#include "../../thread/thread_reduce.hpp" +#include "../../thread/thread_scan.hpp" #include "../../warp/warp_scan.hpp" BEGIN_ROCPRIM_NAMESPACE @@ -36,24 +39,23 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template< - class T, - unsigned int BlockSizeX, - unsigned int BlockSizeY, - unsigned int BlockSizeZ -> +template class block_scan_reduce_then_scan { static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ; // Number of items to reduce per thread static constexpr unsigned int thread_reduction_size_ - = (BlockSize + ::rocprim::arch::wavefront::min_size() - 1) - / ::rocprim::arch::wavefront::min_size(); + = (BlockSize + arch::wavefront::size_from_target() - 1) + / arch::wavefront::size_from_target(); // Warp scan, warp_scan_crosslane does not require shared memory (storage), but // logical warp size must be a power of two. static constexpr unsigned int warp_size_ - = detail::get_min_warp_size(BlockSize, ::rocprim::arch::wavefront::min_size()); + = detail::get_min_warp_size(BlockSize, arch::wavefront::size_from_target()); using warp_scan_prefix_type = ::rocprim::detail::warp_scan_crosslane; // Minimize LDS bank conflicts @@ -147,12 +149,7 @@ class block_scan_reduce_then_scan BinaryFunction scan_op) { // Reduce thread items - T thread_input = input[0]; - ROCPRIM_UNROLL - for(unsigned int i = 1; i < ItemsPerThread; i++) - { - thread_input = scan_op(thread_input, input[i]); - } + T thread_input = ::rocprim::thread_reduce(input, scan_op); // Scan of reduced values to get prefixes const auto flat_tid = ::rocprim::flat_block_thread_id(); @@ -164,14 +161,7 @@ class block_scan_reduce_then_scan ); // Include prefix (first thread does not have prefix) - output[0] = input[0]; - if(flat_tid != 0) output[0] = scan_op(thread_input, input[0]); - // Final thread-local scan - ROCPRIM_UNROLL - for(unsigned int i = 1; i < ItemsPerThread; i++) - { - output[i] = scan_op(output[i-1], input[i]); - } + ::rocprim::thread_scan_inclusive(input, output, scan_op, thread_input, flat_tid > 0); } template @@ -192,33 +182,8 @@ class block_scan_reduce_then_scan storage_type& storage, BinaryFunction scan_op) { - // Reduce thread items - T thread_input = input[0]; - ROCPRIM_UNROLL - for(unsigned int i = 1; i < ItemsPerThread; i++) - { - thread_input = scan_op(thread_input, input[i]); - } - - // Scan of reduced values to get prefixes - const auto flat_tid = ::rocprim::flat_block_thread_id(); - // Calculates inclusive scan, result for each thread is stored in storage_.threads[flat_tid] - this->exclusive_scan_init_impl(flat_tid, - thread_input, - thread_input, // input, output - init, - storage, - scan_op); - - // Include prefix (first thread has init as prefix) - output[0] = scan_op(thread_input, input[0]); - - // Final thread-local scan - ROCPRIM_UNROLL - for(unsigned int i = 1; i < ItemsPerThread; i++) - { - output[i] = scan_op(output[i - 1], input[i]); - } + this->inclusive_scan(input, output, storage, scan_op); + apply_init(init, output, scan_op); } template @@ -266,10 +231,8 @@ class block_scan_reduce_then_scan storage_type& storage, BinaryFunction scan_op) { - storage_type_& storage_ = storage.get(); - this->inclusive_scan(input, init, output, storage, scan_op); - // Save reduction result - reduction = storage_.threads[index(BlockSize - 1)]; + this->inclusive_scan(input, output, reduction, storage, scan_op); + apply_init(init, output, scan_op); } template @@ -298,12 +261,7 @@ class block_scan_reduce_then_scan { storage_type_& storage_ = storage.get(); // Reduce thread items - T thread_input = input[0]; - ROCPRIM_UNROLL - for(unsigned int i = 1; i < ItemsPerThread; i++) - { - thread_input = scan_op(thread_input, input[i]); - } + T thread_input = ::rocprim::thread_reduce(input, scan_op); // Scan of reduced values to get prefixes const auto flat_tid = ::rocprim::flat_block_thread_id(); @@ -601,7 +559,12 @@ class block_scan_reduce_then_scan thread_reduction = input; } + ::rocprim::wave_barrier(); + storage_.threads[idx_start] = thread_reduction; + + ::rocprim::wave_barrier(); + ROCPRIM_UNROLL for(unsigned int i = 1; i < thread_reduction_size_; i++) { @@ -649,7 +612,12 @@ class block_scan_reduce_then_scan thread_reduction = scan_op(init, input); } + ::rocprim::wave_barrier(); + storage_.threads[idx_start] = thread_reduction; + + ::rocprim::wave_barrier(); + ROCPRIM_UNROLL for(unsigned int i = 1; i < thread_reduction_size_; i++) { @@ -758,11 +726,22 @@ class block_scan_reduce_then_scan // Change index to minimize LDS bank conflicts if necessary ROCPRIM_DEVICE ROCPRIM_INLINE - unsigned int index(unsigned int n) const + static unsigned int index(unsigned int n) { // Move every 32-bank wide "row" (32 banks * 4 bytes) by one item return has_bank_conflicts_ ? (n + (n/banks_no_)) : n; } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + static void apply_init(const T& init, T (&items)[N], F scan_op) + { + ROCPRIM_UNROLL + for(int i = 0; i < N; ++i) + { + items[i] = scan_op(init, items[i]); + } + } }; } // end namespace detail diff --git a/rocprim/include/rocprim/block/detail/block_scan_warp_scan.hpp b/rocprim/include/rocprim/block/detail/block_scan_warp_scan.hpp index 17036f0d5..5842caa3e 100644 --- a/rocprim/include/rocprim/block/detail/block_scan_warp_scan.hpp +++ b/rocprim/include/rocprim/block/detail/block_scan_warp_scan.hpp @@ -26,9 +26,9 @@ #include "../../config.hpp" #include "../../detail/various.hpp" -#include "../../intrinsics.hpp" #include "../../functional.hpp" - +#include "../../thread/thread_reduce.hpp" +#include "../../thread/thread_scan.hpp" #include "../../warp/warp_scan.hpp" BEGIN_ROCPRIM_NAMESPACE @@ -36,18 +36,17 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template< - class T, - unsigned int BlockSizeX, - unsigned int BlockSizeY, - unsigned int BlockSizeZ -> +template class block_scan_warp_scan { static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ; // Select warp size static constexpr unsigned int warp_size_ - = detail::get_min_warp_size(BlockSize, ::rocprim::arch::wavefront::min_size()); + = detail::get_min_warp_size(BlockSize, arch::wavefront::size_from_target()); // Number of warps in block static constexpr unsigned int warps_no_ = (BlockSize + warp_size_ - 1) / warp_size_; @@ -158,12 +157,7 @@ class block_scan_warp_scan BinaryFunction scan_op) { // Reduce thread items - T thread_input = input[0]; - ROCPRIM_UNROLL - for(unsigned int i = 1; i < ItemsPerThread; i++) - { - thread_input = scan_op(thread_input, input[i]); - } + T thread_input = ::rocprim::thread_reduce(input, scan_op); // Scan of reduced values to get prefixes const auto flat_tid = ::rocprim::flat_block_thread_id(); @@ -174,19 +168,8 @@ class block_scan_warp_scan scan_op ); - // Include prefix (first thread does not have prefix) - output[0] = input[0]; - if(flat_tid != 0) - { - output[0] = scan_op(thread_input, input[0]); - } - - // Final thread-local scan - ROCPRIM_UNROLL - for(unsigned int i = 1; i < ItemsPerThread; i++) - { - output[i] = scan_op(output[i-1], input[i]); - } + // Include only the 'thread_input' prefix if 'flat_tid' > 0 + ::rocprim::thread_scan_inclusive(input, output, scan_op, thread_input, flat_tid > 0); } template @@ -207,32 +190,8 @@ class block_scan_warp_scan storage_type& storage, BinaryFunction scan_op) { - // Reduce thread items - T thread_input = input[0]; - ROCPRIM_UNROLL - for(unsigned int i = 1; i < ItemsPerThread; i++) - { - thread_input = scan_op(thread_input, input[i]); - } - - // Scan of reduced values to get prefixes - const auto flat_tid = ::rocprim::flat_block_thread_id(); - this->exclusive_scan_init_impl(flat_tid, - thread_input, - thread_input, // input, output - init, - storage, - scan_op); - - // Include prefix (first thread has init as prefix) - output[0] = scan_op(thread_input, input[0]); - - // Final thread-local scan - ROCPRIM_UNROLL - for(unsigned int i = 1; i < ItemsPerThread; i++) - { - output[i] = scan_op(output[i - 1], input[i]); - } + this->inclusive_scan(input, output, storage, scan_op); + apply_init(init, output, scan_op); } template @@ -280,10 +239,8 @@ class block_scan_warp_scan storage_type& storage, BinaryFunction scan_op) { - storage_type_& storage_ = storage.get(); - this->inclusive_scan(input, init, output, storage, scan_op); - // Save reduction result - reduction = storage_.warp_prefixes[warps_no_ - 1]; + this->inclusive_scan(input, output, reduction, storage, scan_op); + apply_init(init, output, scan_op); } template @@ -312,12 +269,7 @@ class block_scan_warp_scan { storage_type_& storage_ = storage.get(); // Reduce thread items - T thread_input = input[0]; - ROCPRIM_UNROLL - for(unsigned int i = 1; i < ItemsPerThread; i++) - { - thread_input = scan_op(thread_input, input[i]); - } + T thread_input = ::rocprim::thread_reduce(input, scan_op); // Scan of reduced values to get prefixes const auto flat_tid = ::rocprim::flat_block_thread_id(); @@ -578,7 +530,8 @@ class block_scan_warp_scan T& output, storage_type& storage, BinaryFunction scan_op) -> - typename std::enable_if<(BlockSize_ > ::rocprim::arch::wavefront::min_size())>::type + typename std::enable_if< + (BlockSize_ > ::rocprim::arch::wavefront::size_from_target())>::type { storage_type_& storage_ = storage.get(); // Perform warp scan @@ -604,7 +557,8 @@ class block_scan_warp_scan ROCPRIM_DEVICE ROCPRIM_INLINE auto inclusive_scan_impl( unsigned int flat_tid, T input, T& output, storage_type& storage, BinaryFunction scan_op) -> - typename std::enable_if ::rocprim::arch::wavefront::min_size())>::type + typename std::enable_if< + !(BlockSize_ > ::rocprim::arch::wavefront::size_from_target())>::type { (void) storage; (void) flat_tid; @@ -631,7 +585,8 @@ class block_scan_warp_scan T init, storage_type& storage, BinaryFunction scan_op) -> - typename std::enable_if<(BlockSize_ > ::rocprim::arch::wavefront::min_size())>::type + typename std::enable_if< + (BlockSize_ > ::rocprim::arch::wavefront::size_from_target())>::type { storage_type_& storage_ = storage.get(); // Perform warp scan on input values @@ -671,7 +626,8 @@ class block_scan_warp_scan T init, storage_type& storage, BinaryFunction scan_op) -> - typename std::enable_if ::rocprim::arch::wavefront::min_size())>::type + typename std::enable_if< + !(BlockSize_ > ::rocprim::arch::wavefront::size_from_target())>::type { (void) flat_tid; (void) storage; @@ -698,87 +654,6 @@ class block_scan_warp_scan } } - // Exclusive scan with initial value when BlockSize is bigger than warp_size - // Warp prefixes stored in storage_.warp_prefixes include the initial value - template - ROCPRIM_DEVICE ROCPRIM_INLINE - auto exclusive_scan_init_impl(const unsigned int flat_tid, - T input, - T& output, - T init, - storage_type& storage, - BinaryFunction scan_op) -> - typename std::enable_if<(BlockSize_ > ::rocprim::arch::wavefront::min_size())>::type - { - storage_type_& storage_ = storage.get(); - // Perform warp scan on input values with init seed - warp_scan_input_type().inclusive_scan( - // not using shared mem, see note in storage_type - input, - output, - scan_op); - - // i-th warp will have its prefix stored in storage_.warp_prefixes[i-1] - const auto warp_id = ::rocprim::warp_id(flat_tid); - this->calculate_warp_prefixes(flat_tid, warp_id, output, init, storage, scan_op); - - // Include initial value in warp prefixes, and fix warp prefixes - // for exclusive scan (first warp prefix is init) - auto warp_prefix = init; - if(warp_id != 0) - { - warp_prefix = storage_.warp_prefixes[warp_id - 1]; - } - - // Use warp prefix to calculate the final scan results for every thread - output = scan_op(warp_prefix, output); // include warp prefix in scan results - output = warp_shuffle_up(output, 1, warp_size_); // shift to get exclusive results - if(::rocprim::lane_id() == 0) - { - output = warp_prefix; - } - } - - // Exclusive scan with initial value when BlockSize is less than warp_size. - // Warp prefixes stored in storage_.warp_prefixes include the initial value - // When BlockSize is less than warp_size we dont need the extra prefix calculations. - template - ROCPRIM_DEVICE ROCPRIM_INLINE - auto exclusive_scan_init_impl(const unsigned int flat_tid, - T input, - T& output, - T init, - storage_type& storage, - BinaryFunction scan_op) -> - typename std::enable_if ::rocprim::arch::wavefront::min_size())>::type - { - (void)flat_tid; - (void)storage; - (void)init; - storage_type_& storage_ = storage.get(); - // Perform warp scan on input values with init seed - warp_scan_input_type().inclusive_scan( - // not using shared mem, see note in storage_type - input, - output, - scan_op, - init); - - if(flat_tid == BlockSize_ - 1) - { - storage_.warp_prefixes[0] = output; - } - ::rocprim::syncthreads(); - - // Use warp prefix to calculate the final scan results for every thread - // output = scan_op(init, output); // include warp prefix in scan results - output = warp_shuffle_up(output, 1, warp_size_); // shift to get exclusive results - if(::rocprim::lane_id() == 0) - { - output = init; - } - } - // Exclusive scan with unknown initial value template ROCPRIM_DEVICE ROCPRIM_INLINE @@ -787,7 +662,8 @@ class block_scan_warp_scan T& output, storage_type& storage, BinaryFunction scan_op) -> - typename std::enable_if<(BlockSize_ > ::rocprim::arch::wavefront::min_size())>::type + typename std::enable_if< + (BlockSize_ > ::rocprim::arch::wavefront::size_from_target())>::type { storage_type_& storage_ = storage.get(); // Perform warp scan on input values @@ -823,7 +699,8 @@ class block_scan_warp_scan T& output, storage_type& storage, BinaryFunction scan_op) -> - typename std::enable_if ::rocprim::arch::wavefront::min_size())>::type + typename std::enable_if< + !(BlockSize_ > ::rocprim::arch::wavefront::size_from_target())>::type { (void) flat_tid; (void) storage; @@ -930,6 +807,17 @@ class block_scan_warp_scan ::rocprim::syncthreads(); return storage_.warp_prefixes[warps_no_ - 1]; } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + static void apply_init(const T& init, T (&items)[N], F scan_op) + { + ROCPRIM_UNROLL + for(int i = 0; i < N; ++i) + { + items[i] = scan_op(init, items[i]); + } + } }; } // end namespace detail diff --git a/rocprim/include/rocprim/config.hpp b/rocprim/include/rocprim/config.hpp index 828b82133..edb947e87 100644 --- a/rocprim/include/rocprim/config.hpp +++ b/rocprim/include/rocprim/config.hpp @@ -62,12 +62,8 @@ END_ROCPRIM_INLINE_NAMESPACE \ } /* namespace rocprim */ -#if __cplusplus == 201402L - #warning "rocPRIM C++14 will be deprecated in the next major release" -#endif - -#if __cplusplus < 201402L - #error "rocPRIM requires at least C++14" +#if __cplusplus < 201703L + #error "rocPRIM requires at least C++17" #endif #if !defined(ROCPRIM_DEVICE) || defined(DOXYGEN_DOCUMENTATION_BUILD) @@ -107,31 +103,44 @@ #undef ROCPRIM_TARGET_CDNA1 #undef ROCPRIM_TARGET_CDNA2 #undef ROCPRIM_TARGET_CDNA3 -#undef ROCPRIM_TARGET_CDNA4 +#undef ROCPRIM_TARGET_UNKNOWN // See https://llvm.org/docs/AMDGPUUsage.html#instructions -#if defined(__gfx950__) - #define ROCPRIM_TARGET_CDNA4 1 -#elif defined(__gfx942__) +#if defined(__gfx942__) || defined(__gfx950__) || defined(__gfx9_4_generic__) #define ROCPRIM_TARGET_CDNA3 1 #elif defined(__gfx90a__) #define ROCPRIM_TARGET_CDNA2 1 #elif defined(__gfx908__) #define ROCPRIM_TARGET_CDNA1 1 #elif defined(__gfx900__) || defined(__gfx902__) || defined(__gfx904__) || defined(__gfx906__) \ - || defined(__gfx90c__) + || defined(__gfx90c__) || defined(__gfx9_generic__) #define ROCPRIM_TARGET_GCN5 1 -#elif defined(__GFX12__) +#elif defined(__GFX12__) || defined(__gfx12_generic__) #define ROCPRIM_TARGET_RDNA4 1 -#elif defined(__GFX11__) +#elif defined(__GFX11__) || defined(__gfx11_generic__) #define ROCPRIM_TARGET_RDNA3 1 #elif defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) \ - || defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) + || defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) \ + || defined(__gfx10_3_generic__) #define ROCPRIM_TARGET_RDNA2 1 -#elif defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__) || defined(__gfx1013__) +#elif defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__) || defined(__gfx1013__) \ + || defined(__gfx10_1_generic__) #define ROCPRIM_TARGET_RDNA1 1 #elif defined(__GFX8__) #define ROCPRIM_TARGET_GCN3 1 +#elif defined(__SPIRV__) + #define ROCPRIM_TARGET_SPIRV 1 +#elif defined(__HIP_DEVICE_COMPILE__) + // Double check the build target for typos otherwise please submit an issue or pull request! + #warning "unknown build target" + #define ROCPRIM_TARGET_UNKNOWN 1 +#endif + +// SPIR-V and unknown targets do not support 128-bit atomics. +#if defined(ROCPRIM_TARGET_UKNOWN) || defined(ROCPRIM_TARGET_SPIRV) + #define ROCPRIM_MAX_ATOMIC_SIZE 8 +#else + #define ROCPRIM_MAX_ATOMIC_SIZE 16 #endif // DPP is supported only after Volcanic Islands (GFX8+) @@ -142,7 +151,7 @@ #define ROCPRIM_DETAIL_HAS_DPP 1 #endif -#if !defined(ROCPRIM_DISABLE_DPP) && defined(ROCPRIM_DETAIL_HAS_DPP) +#if !defined(ROCPRIM_DISABLE_DPP) && defined(ROCPRIM_DETAIL_HAS_DPP) && !ROCPRIM_TARGET_SPIRV #define ROCPRIM_DETAIL_USE_DPP 1 #else #define ROCPRIM_DETAIL_USE_DPP 0 @@ -194,12 +203,6 @@ #define ROCPRIM_GRID_SIZE_LIMIT std::numeric_limits::max() #endif -#if __cpp_if_constexpr >= 201606 - #define ROCPRIM_IF_CONSTEXPR constexpr -#else - #define ROCPRIM_IF_CONSTEXPR -#endif - // Copyright 2001 John Maddock. // Copyright 2017 Peter Dimov. // @@ -240,21 +243,6 @@ #define ROCPRIM_DETAIL_DEVICE_STATIC_ASSERT(expression, message) #endif -/// \brief Clang predefined macro for device code on AMD GPU targets, either 32 or 64. -/// It is undefined behavior to use this macro in host code when compiling with Clang. -#ifndef __AMDGCN_WAVEFRONT_SIZE - #define __AMDGCN_WAVEFRONT_SIZE 64 -#endif - -/// \brief Wavefront size, either 32 or 64. May be defined by compiler flags when compiling -/// with Clang if the value is equal to the wavefront size of all AMD GPU architectures -/// currently being compiled for. -/// -/// Only defined in device code unless defined by compiler flags as described above. -#ifndef ROCPRIM_WAVEFRONT_SIZE - #define ROCPRIM_WAVEFRONT_SIZE __AMDGCN_WAVEFRONT_SIZE -#endif - // Helper macros to disable warnings in clang #ifdef __clang__ #define ROCPRIM_PRAGMA_TO_STR(x) _Pragma(#x) diff --git a/rocprim/include/rocprim/detail/match_result_type.hpp b/rocprim/include/rocprim/detail/match_result_type.hpp deleted file mode 100644 index 40b928982..000000000 --- a/rocprim/include/rocprim/detail/match_result_type.hpp +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All rights reserved. -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. - -#ifndef ROCPRIM_DETAIL_MATCH_RESULT_TYPE_HPP_ -#define ROCPRIM_DETAIL_MATCH_RESULT_TYPE_HPP_ - -#include "../config.hpp" - -#include "../type_traits.hpp" - -ROCPRIM_PRAGMA_MESSAGE("Internal 'match_result_type.hpp'-header has been depracated. Please " - "include 'rocprim/type_traits.hpp' instead!"); - -BEGIN_ROCPRIM_NAMESPACE -namespace detail -{ - -template -using invoke_result [[deprecated("Use 'rocprim::invoke_result' instead!")]] -= rocprim::invoke_result; - -template -using match_result [[deprecated("Use 'rocprim::invoke_result_binary_op' instead!")]] -= rocprim::invoke_result_binary_op; - -template -using match_result_type [[deprecated("Use 'rocprim::invoke_result_binary_op_t' instead!")]] -= rocprim::invoke_result_binary_op_t; - -} // end namespace detail -END_ROCPRIM_NAMESPACE - -#endif // ROCPRIM_DETAIL_MATCH_RESULT_TYPE_HPP_ diff --git a/rocprim/include/rocprim/detail/merge_path.hpp b/rocprim/include/rocprim/detail/merge_path.hpp index eb61fb627..25cd3b312 100644 --- a/rocprim/include/rocprim/detail/merge_path.hpp +++ b/rocprim/include/rocprim/detail/merge_path.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -55,16 +55,22 @@ struct range_t }; template -ROCPRIM_HOST_DEVICE ROCPRIM_INLINE OffsetT merge_path(KeysInputIterator1 keys_input1, - KeysInputIterator2 keys_input2, - const OffsetT input1_size, - const OffsetT input2_size, - const OffsetT diag, - BinaryFunction compare_function) +ROCPRIM_HOST_DEVICE ROCPRIM_INLINE +OffsetT merge_path(KeysInputIterator1 keys_input1, + KeysInputIterator2 keys_input2, + const OffsetT input1_size, + const OffsetT input2_size, + const OffsetT diag, + BinaryFunction compare_function) { using key_type_1 = typename std::iterator_traits::value_type; using key_type_2 = typename std::iterator_traits::value_type; + static_assert( + std::is_convertible_v, + bool>, + "Comparison operator must be convertible to bool"); + OffsetT begin = diag < input2_size ? 0u : diag - input2_size; OffsetT end = min(diag, input1_size); @@ -100,8 +106,8 @@ void serial_merge(KeyType* keys_shared, OutputFunction output_function) { // Pre condition, we're including some edge cases too. - if (!AllowUnsafe && range.begin1 > range.end1 && range.begin2 > range.end2) - return; // don't do anything, we have invalid inputs + if(!AllowUnsafe && range.begin1 > range.end1 && range.begin2 > range.end2) + return; // don't do anything, we have invalid inputs // More descriptive names for ranges: auto idx_a = range.begin1; diff --git a/rocprim/include/rocprim/detail/various.hpp b/rocprim/include/rocprim/detail/various.hpp index fb5d58863..40c14a31c 100644 --- a/rocprim/include/rocprim/detail/various.hpp +++ b/rocprim/include/rocprim/detail/various.hpp @@ -329,25 +329,6 @@ constexpr std::add_const_t* as_const_ptr(T* ptr) return ptr; } -template -ROCPRIM_HOST_DEVICE inline void - for_each_in_tuple_impl(Tuple&& t, Function&& f, ::rocprim::index_sequence) -{ - int swallow[] - = {(std::forward(f)(::rocprim::get(std::forward(t))), 0)...}; - (void)swallow; -} - -template -ROCPRIM_HOST_DEVICE inline auto for_each_in_tuple(Tuple&& t, Function&& f) - -> void_t>> -{ - static constexpr size_t size = tuple_size>::value; - for_each_in_tuple_impl(std::forward(t), - std::forward(f), - ::rocprim::make_index_sequence()); -} - /// \brief Reinterprets the pointer as another type and increments it to match the alignment of /// the new type. /// @@ -395,22 +376,10 @@ ROCPRIM_HOST_DEVICE ROCPRIM_INLINE DstPtr cast_align_down(Src* pointer) } template -ROCPRIM_HOST_DEVICE auto bit_cast(const Source& source) - -> std::enable_if_t::value - && std::is_trivially_copyable::value, - Destination> +ROCPRIM_INLINE ROCPRIM_HOST_DEVICE +auto bit_cast(const Source& source) { -#if defined(__has_builtin) && __has_builtin(__builtin_bit_cast) - return __builtin_bit_cast(Destination, source); -#else - static_assert( - std::is_trivially_constructable::value, - "Fallback implementation of bit_cast requires Destination to be trivially constructible"); - Destination dest; - memcpy(&dest, &source, sizeof(Destination)); - return dest; -#endif + return ::rocprim::traits::radix_key_codec::bit_cast(source); } template diff --git a/rocprim/include/rocprim/device/config_types.hpp b/rocprim/include/rocprim/device/config_types.hpp index 86778f80b..612102964 100644 --- a/rocprim/include/rocprim/device/config_types.hpp +++ b/rocprim/include/rocprim/device/config_types.hpp @@ -255,7 +255,7 @@ constexpr target_arch get_target_arch_from_name(const char* const arch_name, con */ constexpr target_arch device_target_arch() { -#if defined(__amdgcn_processor__) +#if defined(__amdgcn_processor__) && !defined(ROCPRIM_EXPERIMENTAL_SPIRV) // The terminating zero is not counted in the length of the string return get_target_arch_from_name(__amdgcn_processor__, sizeof(__amdgcn_processor__) - sizeof('\0')); @@ -265,10 +265,12 @@ constexpr target_arch device_target_arch() } template -auto dispatch_target_arch(const target_arch target_arch) +auto dispatch_target_arch([[maybe_unused]] const target_arch target_arch) { +#if !defined(ROCPRIM_EXPERIMENTAL_SPIRV) switch(target_arch) { + case target_arch::unknown: return Config::template architecture_config::params; case target_arch::gfx803: @@ -296,6 +298,7 @@ auto dispatch_target_arch(const target_arch target_arch) case target_arch::invalid: assert(false && "Invalid target architecture selected at runtime."); } +#endif return Config::template architecture_config::params; } diff --git a/rocprim/include/rocprim/device/detail/config/device_adjacent_difference.hpp b/rocprim/include/rocprim/device/detail/config/device_adjacent_difference.hpp index 039af66b8..9424ec2e8 100644 --- a/rocprim/include/rocprim/device/detail/config/device_adjacent_difference.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_adjacent_difference.hpp @@ -22,7 +22,7 @@ #define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_ADJACENT_DIFFERENCE_HPP_ #include "../../../config.hpp" -#include "../../../type_traits_interface.hpp" +#include "../../../type_traits.hpp" #include "../../config_types.hpp" #include "../device_config_helper.hpp" diff --git a/rocprim/include/rocprim/device/detail/config/device_adjacent_difference_inplace.hpp b/rocprim/include/rocprim/device/detail/config/device_adjacent_difference_inplace.hpp index 6eb95f43d..5416e638b 100644 --- a/rocprim/include/rocprim/device/detail/config/device_adjacent_difference_inplace.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_adjacent_difference_inplace.hpp @@ -22,7 +22,7 @@ #define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_ADJACENT_DIFFERENCE_INPLACE_HPP_ #include "../../../config.hpp" -#include "../../../type_traits_interface.hpp" +#include "../../../type_traits.hpp" #include "../../config_types.hpp" #include "../device_config_helper.hpp" diff --git a/rocprim/include/rocprim/device/detail/config/device_adjacent_find.hpp b/rocprim/include/rocprim/device/detail/config/device_adjacent_find.hpp index 16951744f..11f92a579 100644 --- a/rocprim/include/rocprim/device/detail/config/device_adjacent_find.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_adjacent_find.hpp @@ -22,7 +22,7 @@ #define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_ADJACENT_FIND_HPP_ #include "../../../config.hpp" -#include "../../../type_traits_interface.hpp" +#include "../../../type_traits.hpp" #include "../../config_types.hpp" #include "../device_config_helper.hpp" diff --git a/rocprim/include/rocprim/device/detail/config/device_binary_search.hpp b/rocprim/include/rocprim/device/detail/config/device_binary_search.hpp index bf34e84be..f45dd7419 100644 --- a/rocprim/include/rocprim/device/detail/config/device_binary_search.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_binary_search.hpp @@ -22,7 +22,7 @@ #define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_BINARY_SEARCH_HPP_ #include "../../../config.hpp" -#include "../../../type_traits_interface.hpp" +#include "../../../type_traits.hpp" #include "../../config_types.hpp" #include "../device_config_helper.hpp" diff --git a/rocprim/include/rocprim/device/detail/config/device_find_first_of.hpp b/rocprim/include/rocprim/device/detail/config/device_find_first_of.hpp index 9cf010f28..768d68e51 100644 --- a/rocprim/include/rocprim/device/detail/config/device_find_first_of.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_find_first_of.hpp @@ -22,7 +22,7 @@ #define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_FIND_FIRST_OF_HPP_ #include "../../../config.hpp" -#include "../../../type_traits_interface.hpp" +#include "../../../type_traits.hpp" #include "../../config_types.hpp" #include "../device_config_helper.hpp" diff --git a/rocprim/include/rocprim/device/detail/config/device_histogram.hpp b/rocprim/include/rocprim/device/detail/config/device_histogram.hpp index 70cb3da03..03dde2899 100644 --- a/rocprim/include/rocprim/device/detail/config/device_histogram.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_histogram.hpp @@ -22,7 +22,7 @@ #define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_HISTOGRAM_HPP_ #include "../../../config.hpp" -#include "../../../type_traits_interface.hpp" +#include "../../../type_traits.hpp" #include "../../config_types.hpp" #include "../device_config_helper.hpp" diff --git a/rocprim/include/rocprim/device/detail/config/device_lower_bound.hpp b/rocprim/include/rocprim/device/detail/config/device_lower_bound.hpp index 263d8cd19..76ebeda97 100644 --- a/rocprim/include/rocprim/device/detail/config/device_lower_bound.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_lower_bound.hpp @@ -22,7 +22,7 @@ #define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_LOWER_BOUND_HPP_ #include "../../../config.hpp" -#include "../../../type_traits_interface.hpp" +#include "../../../type_traits.hpp" #include "../../config_types.hpp" #include "../device_config_helper.hpp" diff --git a/rocprim/include/rocprim/device/detail/config/device_merge.hpp b/rocprim/include/rocprim/device/detail/config/device_merge.hpp index caada274c..623aece8b 100644 --- a/rocprim/include/rocprim/device/detail/config/device_merge.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_merge.hpp @@ -22,7 +22,7 @@ #define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_MERGE_HPP_ #include "../../../config.hpp" -#include "../../../type_traits_interface.hpp" +#include "../../../type_traits.hpp" #include "../../config_types.hpp" #include "../device_config_helper.hpp" diff --git a/rocprim/include/rocprim/device/detail/config/device_merge_sort_block_merge.hpp b/rocprim/include/rocprim/device/detail/config/device_merge_sort_block_merge.hpp index 07948451b..1d5b4b410 100644 --- a/rocprim/include/rocprim/device/detail/config/device_merge_sort_block_merge.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_merge_sort_block_merge.hpp @@ -22,7 +22,7 @@ #define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_MERGE_SORT_BLOCK_MERGE_HPP_ #include "../../../config.hpp" -#include "../../../type_traits_interface.hpp" +#include "../../../type_traits.hpp" #include "../../config_types.hpp" #include "../device_config_helper.hpp" diff --git a/rocprim/include/rocprim/device/detail/config/device_merge_sort_block_sort.hpp b/rocprim/include/rocprim/device/detail/config/device_merge_sort_block_sort.hpp index 5e49039c4..5cb53da3a 100644 --- a/rocprim/include/rocprim/device/detail/config/device_merge_sort_block_sort.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_merge_sort_block_sort.hpp @@ -22,7 +22,7 @@ #define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_MERGE_SORT_BLOCK_SORT_HPP_ #include "../../../config.hpp" -#include "../../../type_traits_interface.hpp" +#include "../../../type_traits.hpp" #include "../../config_types.hpp" #include "../device_config_helper.hpp" diff --git a/rocprim/include/rocprim/device/detail/config/device_partition_flag.hpp b/rocprim/include/rocprim/device/detail/config/device_partition_flag.hpp index 62e355c89..0c0a1b540 100644 --- a/rocprim/include/rocprim/device/detail/config/device_partition_flag.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_partition_flag.hpp @@ -22,7 +22,7 @@ #define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_PARTITION_FLAG_HPP_ #include "../../../config.hpp" -#include "../../../type_traits_interface.hpp" +#include "../../../type_traits.hpp" #include "../../config_types.hpp" #include "../device_config_helper.hpp" diff --git a/rocprim/include/rocprim/device/detail/config/device_partition_predicate.hpp b/rocprim/include/rocprim/device/detail/config/device_partition_predicate.hpp index d69b4ad7c..80658c116 100644 --- a/rocprim/include/rocprim/device/detail/config/device_partition_predicate.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_partition_predicate.hpp @@ -22,7 +22,7 @@ #define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_PARTITION_PREDICATE_HPP_ #include "../../../config.hpp" -#include "../../../type_traits_interface.hpp" +#include "../../../type_traits.hpp" #include "../../config_types.hpp" #include "../device_config_helper.hpp" diff --git a/rocprim/include/rocprim/device/detail/config/device_partition_three_way.hpp b/rocprim/include/rocprim/device/detail/config/device_partition_three_way.hpp index 1347cda10..eac8d6632 100644 --- a/rocprim/include/rocprim/device/detail/config/device_partition_three_way.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_partition_three_way.hpp @@ -22,7 +22,7 @@ #define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_PARTITION_THREE_WAY_HPP_ #include "../../../config.hpp" -#include "../../../type_traits_interface.hpp" +#include "../../../type_traits.hpp" #include "../../config_types.hpp" #include "../device_config_helper.hpp" diff --git a/rocprim/include/rocprim/device/detail/config/device_partition_two_way_flag.hpp b/rocprim/include/rocprim/device/detail/config/device_partition_two_way_flag.hpp index f72284cab..82dfb30f2 100644 --- a/rocprim/include/rocprim/device/detail/config/device_partition_two_way_flag.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_partition_two_way_flag.hpp @@ -22,7 +22,7 @@ #define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_PARTITION_TWO_WAY_FLAG_HPP_ #include "../../../config.hpp" -#include "../../../type_traits_interface.hpp" +#include "../../../type_traits.hpp" #include "../../config_types.hpp" #include "../device_config_helper.hpp" diff --git a/rocprim/include/rocprim/device/detail/config/device_partition_two_way_predicate.hpp b/rocprim/include/rocprim/device/detail/config/device_partition_two_way_predicate.hpp index 95a523fcb..a6fdba90c 100644 --- a/rocprim/include/rocprim/device/detail/config/device_partition_two_way_predicate.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_partition_two_way_predicate.hpp @@ -22,7 +22,7 @@ #define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_PARTITION_TWO_WAY_PREDICATE_HPP_ #include "../../../config.hpp" -#include "../../../type_traits_interface.hpp" +#include "../../../type_traits.hpp" #include "../../config_types.hpp" #include "../device_config_helper.hpp" diff --git a/rocprim/include/rocprim/device/detail/config/device_radix_sort_block_sort.hpp b/rocprim/include/rocprim/device/detail/config/device_radix_sort_block_sort.hpp index 7d7a01739..556c91587 100644 --- a/rocprim/include/rocprim/device/detail/config/device_radix_sort_block_sort.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_radix_sort_block_sort.hpp @@ -22,7 +22,7 @@ #define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_RADIX_SORT_BLOCK_SORT_HPP_ #include "../../../config.hpp" -#include "../../../type_traits_interface.hpp" +#include "../../../type_traits.hpp" #include "../../config_types.hpp" #include "../device_config_helper.hpp" diff --git a/rocprim/include/rocprim/device/detail/config/device_radix_sort_onesweep.hpp b/rocprim/include/rocprim/device/detail/config/device_radix_sort_onesweep.hpp index 843f94487..0f549c9fa 100644 --- a/rocprim/include/rocprim/device/detail/config/device_radix_sort_onesweep.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_radix_sort_onesweep.hpp @@ -22,7 +22,7 @@ #define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_RADIX_SORT_ONESWEEP_HPP_ #include "../../../config.hpp" -#include "../../../type_traits_interface.hpp" +#include "../../../type_traits.hpp" #include "../../config_types.hpp" #include "../device_config_helper.hpp" diff --git a/rocprim/include/rocprim/device/detail/config/device_reduce.hpp b/rocprim/include/rocprim/device/detail/config/device_reduce.hpp index d79ab2dcd..3ce0a9cb7 100644 --- a/rocprim/include/rocprim/device/detail/config/device_reduce.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_reduce.hpp @@ -22,7 +22,7 @@ #define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_REDUCE_HPP_ #include "../../../config.hpp" -#include "../../../type_traits_interface.hpp" +#include "../../../type_traits.hpp" #include "../../config_types.hpp" #include "../device_config_helper.hpp" diff --git a/rocprim/include/rocprim/device/detail/config/device_run_length_encode_non_trivial.hpp b/rocprim/include/rocprim/device/detail/config/device_run_length_encode_non_trivial.hpp index fe77ab86a..a453649ea 100644 --- a/rocprim/include/rocprim/device/detail/config/device_run_length_encode_non_trivial.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_run_length_encode_non_trivial.hpp @@ -22,7 +22,7 @@ #define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_RUN_LENGTH_ENCODE_NON_TRIVIAL_RUNS_HPP_ #include "../../../config.hpp" -#include "../../../type_traits_interface.hpp" +#include "../../../type_traits.hpp" #include "../../config_types.hpp" #include "../device_config_helper.hpp" diff --git a/rocprim/include/rocprim/device/detail/config/device_scan.hpp b/rocprim/include/rocprim/device/detail/config/device_scan.hpp index 3d3990537..c6de6d57d 100644 --- a/rocprim/include/rocprim/device/detail/config/device_scan.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_scan.hpp @@ -22,7 +22,7 @@ #define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SCAN_HPP_ #include "../../../config.hpp" -#include "../../../type_traits_interface.hpp" +#include "../../../type_traits.hpp" #include "../../config_types.hpp" #include "../device_config_helper.hpp" diff --git a/rocprim/include/rocprim/device/detail/config/device_scan_by_key.hpp b/rocprim/include/rocprim/device/detail/config/device_scan_by_key.hpp index 479bac10e..62cff0638 100644 --- a/rocprim/include/rocprim/device/detail/config/device_scan_by_key.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_scan_by_key.hpp @@ -22,7 +22,7 @@ #define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SCAN_BY_KEY_HPP_ #include "../../../config.hpp" -#include "../../../type_traits_interface.hpp" +#include "../../../type_traits.hpp" #include "../../config_types.hpp" #include "../device_config_helper.hpp" diff --git a/rocprim/include/rocprim/device/detail/config/device_search_n.hpp b/rocprim/include/rocprim/device/detail/config/device_search_n.hpp new file mode 100644 index 000000000..6e336f0f5 --- /dev/null +++ b/rocprim/include/rocprim/device/detail/config/device_search_n.hpp @@ -0,0 +1,511 @@ +// Copyright (c) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SEARCH_N_HPP_ +#define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SEARCH_N_HPP_ + +#include "../../../config.hpp" +#include "../../../type_traits.hpp" +#include "../../config_types.hpp" +#include "../device_config_helper.hpp" + +#include + +/* DO NOT EDIT THIS FILE + * This file is automatically generated by `/scripts/autotune/create_optimization.py`. + * so most likely you want to edit rocprim/device/device_(algo)_config.hpp + */ + +/// \addtogroup primitivesmodule_deviceconfigs +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template +struct default_search_n_config : default_search_n_config_base::type +{}; + +// Based on data_type = double +template +struct default_search_n_config< + static_cast(target_arch::gfx1030), + data_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))>> : search_n_config<512, 4, 8> +{}; + +// Based on data_type = float +template +struct default_search_n_config< + static_cast(target_arch::gfx1030), + data_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))>> : search_n_config<256, 4, 8> +{}; + +// Based on data_type = rocprim::half +template +struct default_search_n_config(target_arch::gfx1030), + data_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2))>> + : search_n_config<1024, 8, 16> +{}; + +// Based on data_type = rocprim::int128_t +template +struct default_search_n_config< + static_cast(target_arch::gfx1030), + data_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> + : search_n_config<1024, 1, 8> +{}; + +// Based on data_type = int64_t +template +struct default_search_n_config< + static_cast(target_arch::gfx1030), + data_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> + : search_n_config<512, 4, 16> +{}; + +// Based on data_type = int +template +struct default_search_n_config< + static_cast(target_arch::gfx1030), + data_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> + : search_n_config<256, 4, 8> +{}; + +// Based on data_type = short +template +struct default_search_n_config< + static_cast(target_arch::gfx1030), + data_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> + : search_n_config<1024, 8, 16> +{}; + +// Based on data_type = int8_t +template +struct default_search_n_config(target_arch::gfx1030), + data_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 1))>> + : search_n_config<1024, 16, 16> +{}; + +// Based on data_type = double +template +struct default_search_n_config< + static_cast(target_arch::gfx1100), + data_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))>> : search_n_config<256, 4, 8> +{}; + +// Based on data_type = float +template +struct default_search_n_config< + static_cast(target_arch::gfx1100), + data_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))>> : search_n_config<256, 4, 8> +{}; + +// Based on data_type = rocprim::half +template +struct default_search_n_config(target_arch::gfx1100), + data_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2))>> + : search_n_config<1024, 16, 16> +{}; + +// Based on data_type = rocprim::int128_t +template +struct default_search_n_config< + static_cast(target_arch::gfx1100), + data_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> + : search_n_config<256, 1, 8> +{}; + +// Based on data_type = int64_t +template +struct default_search_n_config< + static_cast(target_arch::gfx1100), + data_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> + : search_n_config<256, 4, 8> +{}; + +// Based on data_type = int +template +struct default_search_n_config< + static_cast(target_arch::gfx1100), + data_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> + : search_n_config<64, 8, 16> +{}; + +// Based on data_type = short +template +struct default_search_n_config< + static_cast(target_arch::gfx1100), + data_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> + : search_n_config<1024, 16, 16> +{}; + +// Based on data_type = int8_t +template +struct default_search_n_config(target_arch::gfx1100), + data_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 1))>> + : search_n_config<1024, 16, 16> +{}; + +// Based on data_type = double +template +struct default_search_n_config< + static_cast(target_arch::gfx906), + data_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))>> : search_n_config<128, 2, 8> +{}; + +// Based on data_type = float +template +struct default_search_n_config< + static_cast(target_arch::gfx906), + data_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))>> : search_n_config<256, 4, 8> +{}; + +// Based on data_type = rocprim::half +template +struct default_search_n_config(target_arch::gfx906), + data_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2))>> + : search_n_config<256, 4, 8> +{}; + +// Based on data_type = rocprim::int128_t +template +struct default_search_n_config< + static_cast(target_arch::gfx906), + data_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> + : search_n_config<256, 1, 4> +{}; + +// Based on data_type = int64_t +template +struct default_search_n_config< + static_cast(target_arch::gfx906), + data_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> + : search_n_config<128, 2, 4> +{}; + +// Based on data_type = int +template +struct default_search_n_config< + static_cast(target_arch::gfx906), + data_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> + : search_n_config<256, 2, 8> +{}; + +// Based on data_type = short +template +struct default_search_n_config< + static_cast(target_arch::gfx906), + data_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> + : search_n_config<256, 4, 8> +{}; + +// Based on data_type = int8_t +template +struct default_search_n_config(target_arch::gfx906), + data_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 1))>> + : search_n_config<256, 4, 8> +{}; + +// Based on data_type = double +template +struct default_search_n_config< + static_cast(target_arch::gfx908), + data_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))>> : search_n_config<1024, 2, 4> +{}; + +// Based on data_type = float +template +struct default_search_n_config< + static_cast(target_arch::gfx908), + data_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))>> : search_n_config<256, 2, 8> +{}; + +// Based on data_type = rocprim::half +template +struct default_search_n_config(target_arch::gfx908), + data_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2))>> + : search_n_config<256, 4, 12> +{}; + +// Based on data_type = rocprim::int128_t +template +struct default_search_n_config< + static_cast(target_arch::gfx908), + data_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> + : search_n_config<512, 1, 12> +{}; + +// Based on data_type = int64_t +template +struct default_search_n_config< + static_cast(target_arch::gfx908), + data_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> + : search_n_config<1024, 2, 12> +{}; + +// Based on data_type = int +template +struct default_search_n_config< + static_cast(target_arch::gfx908), + data_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> + : search_n_config<256, 2, 8> +{}; + +// Based on data_type = short +template +struct default_search_n_config< + static_cast(target_arch::gfx908), + data_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> + : search_n_config<256, 4, 12> +{}; + +// Based on data_type = int8_t +template +struct default_search_n_config(target_arch::gfx908), + data_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 1))>> + : search_n_config<256, 4, 8> +{}; + +// Based on data_type = double +template +struct default_search_n_config< + static_cast(target_arch::gfx90a), + data_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))>> : search_n_config<128, 2, 8> +{}; + +// Based on data_type = float +template +struct default_search_n_config< + static_cast(target_arch::gfx90a), + data_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))>> : search_n_config<256, 2, 8> +{}; + +// Based on data_type = rocprim::half +template +struct default_search_n_config(target_arch::gfx90a), + data_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2))>> + : search_n_config<128, 4, 8> +{}; + +// Based on data_type = rocprim::int128_t +template +struct default_search_n_config< + static_cast(target_arch::gfx90a), + data_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> + : search_n_config<128, 1, 4> +{}; + +// Based on data_type = int64_t +template +struct default_search_n_config< + static_cast(target_arch::gfx90a), + data_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> + : search_n_config<128, 2, 8> +{}; + +// Based on data_type = int +template +struct default_search_n_config< + static_cast(target_arch::gfx90a), + data_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> + : search_n_config<256, 2, 8> +{}; + +// Based on data_type = short +template +struct default_search_n_config< + static_cast(target_arch::gfx90a), + data_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> + : search_n_config<128, 4, 8> +{}; + +// Based on data_type = int8_t +template +struct default_search_n_config(target_arch::gfx90a), + data_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 1))>> + : search_n_config<128, 4, 8> +{}; + +// Based on data_type = double +template +struct default_search_n_config< + static_cast(target_arch::unknown), + data_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))>> : search_n_config<1024, 2, 4> +{}; + +// Based on data_type = float +template +struct default_search_n_config< + static_cast(target_arch::unknown), + data_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))>> : search_n_config<256, 2, 4> +{}; + +// Based on data_type = rocprim::half +template +struct default_search_n_config(target_arch::unknown), + data_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2))>> + : search_n_config<256, 4, 12> +{}; + +// Based on data_type = rocprim::int128_t +template +struct default_search_n_config< + static_cast(target_arch::unknown), + data_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> + : search_n_config<512, 1, 12> +{}; + +// Based on data_type = int64_t +template +struct default_search_n_config< + static_cast(target_arch::unknown), + data_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> + : search_n_config<1024, 2, 16> +{}; + +// Based on data_type = int +template +struct default_search_n_config< + static_cast(target_arch::unknown), + data_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> + : search_n_config<256, 2, 8> +{}; + +// Based on data_type = short +template +struct default_search_n_config< + static_cast(target_arch::unknown), + data_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> + : search_n_config<256, 4, 16> +{}; + +// Based on data_type = int8_t +template +struct default_search_n_config(target_arch::unknown), + data_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 1))>> + : search_n_config<256, 4, 8> +{}; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group primitivesmodule_deviceconfigs + +#endif // ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SEARCH_N_HPP_ \ No newline at end of file diff --git a/rocprim/include/rocprim/device/detail/config/device_segmented_radix_sort.hpp b/rocprim/include/rocprim/device/detail/config/device_segmented_radix_sort.hpp index 1003b35e7..bd5504b35 100644 --- a/rocprim/include/rocprim/device/detail/config/device_segmented_radix_sort.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_segmented_radix_sort.hpp @@ -22,7 +22,7 @@ #define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SEGMENTED_RADIX_SORT_HPP_ #include "../../../config.hpp" -#include "../../../type_traits_interface.hpp" +#include "../../../type_traits.hpp" #include "../../config_types.hpp" #include "../device_config_helper.hpp" @@ -42,7 +42,7 @@ namespace detail { template -struct default_segmented_radix_sort_config : default_segmented_radix_sort_config_base<6, 4>::type +struct default_segmented_radix_sort_config : default_segmented_radix_sort_config_base<6>::type {}; // Based on key_type = double, value_type = rocprim::int128_t @@ -55,7 +55,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -75,7 +74,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, @@ -94,7 +92,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -113,7 +110,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -132,7 +128,6 @@ struct default_segmented_radix_sort_config< && (!std::is_same::value))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -151,7 +146,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -169,7 +163,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -189,7 +182,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< 7, - 2, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, @@ -208,7 +200,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< 7, - 2, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, @@ -227,7 +218,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 7, - 2, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -246,7 +236,6 @@ struct default_segmented_radix_sort_config< && (!std::is_same::value))>> : segmented_radix_sort_config< 7, - 2, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -265,7 +254,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 5, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, @@ -282,7 +270,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -300,7 +287,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -318,7 +304,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 4>, typename std::conditional<1, @@ -337,7 +322,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 6, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, @@ -355,7 +339,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) <= 1) && (!std::is_same::value))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 4>, typename std::conditional<1, @@ -374,7 +357,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -393,7 +375,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< 7, - 6, kernel_config<128, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -411,7 +392,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -430,7 +410,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -449,7 +428,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -468,7 +446,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -488,7 +465,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<128, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -507,7 +483,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< 7, - 6, kernel_config<128, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -526,7 +501,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -545,7 +519,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -564,7 +537,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -583,7 +555,6 @@ struct default_segmented_radix_sort_config< && (!std::is_same::value))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -602,7 +573,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -620,7 +590,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -640,7 +609,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -659,7 +627,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -678,7 +645,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -697,7 +663,6 @@ struct default_segmented_radix_sort_config< && (!std::is_same::value))>> : segmented_radix_sort_config< 7, - 2, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -716,7 +681,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 2, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -734,7 +698,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -754,7 +717,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -773,7 +735,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -792,7 +753,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 6, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -811,7 +771,6 @@ struct default_segmented_radix_sort_config< && (!std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -830,7 +789,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -847,7 +805,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 4>, typename std::conditional<1, @@ -866,7 +823,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -884,7 +840,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -902,7 +857,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -921,7 +875,6 @@ struct default_segmented_radix_sort_config< && (!std::is_same::value))>> : segmented_radix_sort_config< 6, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, @@ -939,7 +892,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -957,7 +909,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -976,7 +927,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 16>, typename std::conditional<1, @@ -995,7 +945,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 16>, typename std::conditional<1, @@ -1014,7 +963,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 16>, typename std::conditional<1, @@ -1033,7 +981,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 16>, typename std::conditional<1, @@ -1053,7 +1000,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<128, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -1071,7 +1017,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -1090,7 +1035,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -1109,7 +1053,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -1128,7 +1071,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -1147,7 +1089,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -1167,7 +1108,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<128, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -1184,7 +1124,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -1202,7 +1141,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -1220,7 +1158,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -1239,7 +1176,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 7, - 6, kernel_config<128, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -1257,7 +1193,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) <= 1) && (!std::is_same::value))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 4>, typename std::conditional<1, @@ -1276,7 +1211,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<128, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -1295,7 +1229,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< 7, - 6, kernel_config<128, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -1313,7 +1246,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 4>, typename std::conditional<1, @@ -1332,7 +1264,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 4>, typename std::conditional<1, @@ -1351,7 +1282,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 4>, typename std::conditional<1, @@ -1370,7 +1300,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 4>, typename std::conditional<1, @@ -1390,7 +1319,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<128, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -1409,7 +1337,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< 7, - 6, kernel_config<128, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -1428,7 +1355,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< 7, - 6, kernel_config<128, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -1446,7 +1372,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 16>, typename std::conditional<1, @@ -1465,7 +1390,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 16>, typename std::conditional<1, @@ -1484,7 +1408,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 16>, typename std::conditional<1, @@ -1504,7 +1427,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 8, - 0, kernel_config<256, 8>, typename std::conditional<1, WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, @@ -1522,7 +1444,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -1542,7 +1463,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< 7, - 6, kernel_config<128, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -1561,7 +1481,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< 7, - 6, kernel_config<128, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -1579,7 +1498,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -1598,7 +1516,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -1618,7 +1535,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<128, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -1636,7 +1552,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -1655,7 +1570,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -1674,7 +1588,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -1693,7 +1606,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -1712,7 +1624,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -1732,7 +1643,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 8, - 0, kernel_config<256, 4>, typename std::conditional<1, WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, @@ -1749,7 +1659,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 4>, typename std::conditional<1, @@ -1767,7 +1676,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -1785,7 +1693,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -1803,7 +1710,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 4>, typename std::conditional<1, @@ -1823,7 +1729,6 @@ struct default_segmented_radix_sort_config< && (!std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<128, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -1841,7 +1746,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<128, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -1860,7 +1764,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 8, - 0, kernel_config<256, 8>, typename std::conditional<1, WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, @@ -1879,7 +1782,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 8, - 0, kernel_config<256, 8>, typename std::conditional<1, WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, @@ -1897,7 +1799,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 8, - 0, kernel_config<256, 8>, typename std::conditional<1, WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, @@ -1916,7 +1817,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 8, - 0, kernel_config<256, 8>, typename std::conditional<1, WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, @@ -1935,7 +1835,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 8, - 0, kernel_config<256, 8>, typename std::conditional<1, WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, @@ -1954,7 +1853,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 8, - 0, kernel_config<256, 8>, typename std::conditional<1, WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, @@ -1972,7 +1870,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 8, - 0, kernel_config<256, 8>, typename std::conditional<1, WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, @@ -1990,7 +1887,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 4>, typename std::conditional<1, @@ -2009,7 +1905,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -2028,7 +1923,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) <= 1) && (!std::is_same::value))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 4>, typename std::conditional<1, @@ -2047,7 +1941,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 4>, typename std::conditional<1, @@ -2066,7 +1959,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 4>, typename std::conditional<1, @@ -2085,7 +1977,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 4>, typename std::conditional<1, @@ -2104,7 +1995,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) <= 1) && (!std::is_same::value))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -2123,7 +2013,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -2143,7 +2032,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< 7, - 2, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, @@ -2162,7 +2050,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, @@ -2181,7 +2068,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -2200,7 +2086,6 @@ struct default_segmented_radix_sort_config< && (!std::is_same::value))>> : segmented_radix_sort_config< 7, - 2, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -2219,7 +2104,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, @@ -2237,7 +2121,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -2257,7 +2140,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -2276,7 +2158,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, @@ -2295,7 +2176,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -2314,7 +2194,6 @@ struct default_segmented_radix_sort_config< && (!std::is_same::value))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -2333,7 +2212,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -2350,7 +2228,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -2368,7 +2245,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -2386,7 +2262,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 4>, typename std::conditional<1, @@ -2405,7 +2280,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 6, - 4, kernel_config<256, 8>, typename std::conditional<1, WarpSortConfig<16, 4, 256, 5, 32, 8, 256>, @@ -2423,7 +2297,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) <= 1) && (!std::is_same::value))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 4>, typename std::conditional<1, @@ -2442,7 +2315,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -2461,7 +2333,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< 6, - 4, kernel_config<128, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -2479,7 +2350,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -2498,7 +2368,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -2517,7 +2386,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -2536,7 +2404,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -2556,7 +2423,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<128, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -2575,7 +2441,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< 7, - 6, kernel_config<128, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -2594,7 +2459,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, @@ -2613,7 +2477,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -2632,7 +2495,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -2651,7 +2513,6 @@ struct default_segmented_radix_sort_config< && (!std::is_same::value))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -2670,7 +2531,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -2688,7 +2548,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -2708,7 +2567,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -2727,7 +2585,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -2746,7 +2603,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, @@ -2765,7 +2621,6 @@ struct default_segmented_radix_sort_config< && (!std::is_same::value))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -2784,7 +2639,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -2802,7 +2656,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -2822,7 +2675,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -2841,7 +2693,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -2860,7 +2711,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -2879,7 +2729,6 @@ struct default_segmented_radix_sort_config< && (!std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -2898,7 +2747,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -2915,7 +2763,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 4>, typename std::conditional<1, @@ -2934,7 +2781,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 4>, typename std::conditional<1, WarpSortConfig<16, 2, 256, 5, 32, 4, 256>, @@ -2951,7 +2797,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> : segmented_radix_sort_config<6, - 4, kernel_config<256, 8>, typename std::conditional<1, WarpSortConfig<4, 4, 256, 5, 8, 8, 256>, @@ -2969,7 +2814,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 6, - 4, kernel_config<256, 8>, typename std::conditional<1, WarpSortConfig<16, 4, 256, 5, 32, 8, 256>, @@ -2988,7 +2832,6 @@ struct default_segmented_radix_sort_config< && (!std::is_same::value))>> : segmented_radix_sort_config< 6, - 4, kernel_config<256, 4>, typename std::conditional<1, WarpSortConfig<8, 2, 256, 5, 16, 4, 256>, @@ -3006,7 +2849,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -3024,7 +2866,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -3044,7 +2885,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -3063,7 +2903,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -3082,7 +2921,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, @@ -3101,7 +2939,6 @@ struct default_segmented_radix_sort_config< && (!std::is_same::value))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, @@ -3120,7 +2957,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, @@ -3138,7 +2974,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -3158,7 +2993,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -3177,7 +3011,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -3196,7 +3029,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -3215,7 +3047,6 @@ struct default_segmented_radix_sort_config< && (!std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -3234,7 +3065,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -3251,7 +3081,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -3269,7 +3098,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 16>, typename std::conditional<1, @@ -3287,7 +3115,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 16>, typename std::conditional<1, @@ -3306,7 +3133,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -3324,7 +3150,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) <= 1) && (!std::is_same::value))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 16>, typename std::conditional<1, @@ -3343,7 +3168,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -3362,7 +3186,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< 7, - 6, kernel_config<128, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -3380,7 +3203,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -3399,7 +3221,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -3418,7 +3239,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -3437,7 +3257,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -3457,7 +3276,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<128, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -3476,7 +3294,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< 7, - 6, kernel_config<128, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -3495,7 +3312,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -3514,7 +3330,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, @@ -3533,7 +3348,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -3552,7 +3366,6 @@ struct default_segmented_radix_sort_config< && (!std::is_same::value))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -3571,7 +3384,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -3589,7 +3401,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -3609,7 +3420,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -3628,7 +3438,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, @@ -3647,7 +3456,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -3666,7 +3474,6 @@ struct default_segmented_radix_sort_config< && (!std::is_same::value))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -3685,7 +3492,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -3703,7 +3509,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -3723,7 +3528,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -3742,7 +3546,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -3761,7 +3564,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -3780,7 +3582,6 @@ struct default_segmented_radix_sort_config< && (!std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -3799,7 +3600,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -3816,7 +3616,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -3835,7 +3634,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -3853,7 +3651,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -3871,7 +3668,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -3890,7 +3686,6 @@ struct default_segmented_radix_sort_config< && (!std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -3908,7 +3703,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -3926,7 +3720,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -3946,7 +3739,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -3965,7 +3757,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -3984,7 +3775,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, @@ -4003,7 +3793,6 @@ struct default_segmented_radix_sort_config< && (!std::is_same::value))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, @@ -4022,7 +3811,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, @@ -4040,7 +3828,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -4060,7 +3847,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -4079,7 +3865,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -4098,7 +3883,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -4117,7 +3901,6 @@ struct default_segmented_radix_sort_config< && (!std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -4136,7 +3919,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -4153,7 +3935,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -4171,7 +3952,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -4189,7 +3969,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -4208,7 +3987,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -4226,7 +4004,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) <= 1) && (!std::is_same::value))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -4245,7 +4022,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -4263,7 +4039,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -4282,7 +4057,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -4301,7 +4075,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -4320,7 +4093,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -4339,7 +4111,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -4359,7 +4130,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<128, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -4378,7 +4148,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< 7, - 6, kernel_config<128, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -4397,7 +4166,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -4416,7 +4184,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, @@ -4435,7 +4202,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -4454,7 +4220,6 @@ struct default_segmented_radix_sort_config< && (!std::is_same::value))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -4473,7 +4238,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -4491,7 +4255,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -4511,7 +4274,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -4530,7 +4292,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, @@ -4549,7 +4310,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -4568,7 +4328,6 @@ struct default_segmented_radix_sort_config< && (!std::is_same::value))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -4587,7 +4346,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -4605,7 +4363,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -4625,7 +4382,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -4644,7 +4400,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -4663,7 +4418,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -4682,7 +4436,6 @@ struct default_segmented_radix_sort_config< && (!std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -4701,7 +4454,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -4718,7 +4470,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 4>, typename std::conditional<1, @@ -4737,7 +4488,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -4755,7 +4505,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -4773,7 +4522,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -4792,7 +4540,6 @@ struct default_segmented_radix_sort_config< && (!std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -4810,7 +4557,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -4828,7 +4574,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -4848,7 +4593,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -4867,7 +4611,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -4886,7 +4629,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, @@ -4905,7 +4647,6 @@ struct default_segmented_radix_sort_config< && (!std::is_same::value))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, @@ -4924,7 +4665,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, @@ -4942,7 +4682,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -4962,7 +4701,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -4981,7 +4719,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -5000,7 +4737,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -5019,7 +4755,6 @@ struct default_segmented_radix_sort_config< && (!std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -5038,7 +4773,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -5055,7 +4789,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -5073,7 +4806,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 16>, typename std::conditional<1, @@ -5091,7 +4823,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 16>, typename std::conditional<1, @@ -5110,7 +4841,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -5128,7 +4858,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) <= 1) && (!std::is_same::value))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 16>, typename std::conditional<1, @@ -5147,7 +4876,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -5166,7 +4894,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< 7, - 6, kernel_config<128, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -5184,7 +4911,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -5203,7 +4929,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -5222,7 +4947,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -5241,7 +4965,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -5261,7 +4984,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<128, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -5280,7 +5002,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< 7, - 6, kernel_config<128, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -5299,7 +5020,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -5318,7 +5038,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, @@ -5337,7 +5056,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -5356,7 +5074,6 @@ struct default_segmented_radix_sort_config< && (!std::is_same::value))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -5375,7 +5092,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -5393,7 +5109,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -5413,7 +5128,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -5432,7 +5146,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, @@ -5451,7 +5164,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -5470,7 +5182,6 @@ struct default_segmented_radix_sort_config< && (!std::is_same::value))>> : segmented_radix_sort_config< 7, - 4, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, @@ -5489,7 +5200,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -5507,7 +5217,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -5527,7 +5236,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -5546,7 +5254,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -5565,7 +5272,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -5584,7 +5290,6 @@ struct default_segmented_radix_sort_config< && (!std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -5603,7 +5308,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -5620,7 +5324,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -5639,7 +5342,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -5657,7 +5359,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -5675,7 +5376,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -5694,7 +5394,6 @@ struct default_segmented_radix_sort_config< && (!std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -5712,7 +5411,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 7, - 6, kernel_config<256, 17>, typename std::conditional<1, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, @@ -5730,7 +5428,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -5749,7 +5446,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 16>, typename std::conditional<1, @@ -5768,7 +5464,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -5787,7 +5482,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 16>, typename std::conditional<1, @@ -5806,7 +5500,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -5826,7 +5519,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 8, - 0, kernel_config<256, 8>, typename std::conditional<1, WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, @@ -5844,7 +5536,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 4>, typename std::conditional<1, @@ -5863,7 +5554,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -5882,7 +5572,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -5901,7 +5590,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -5920,7 +5608,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -5940,7 +5627,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 8, - 0, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, @@ -5957,7 +5643,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 4>, typename std::conditional<1, @@ -5975,7 +5660,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -5993,7 +5677,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -6011,7 +5694,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -6030,7 +5712,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) <= 1) && (!std::is_same::value))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -6049,7 +5730,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 8, - 0, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, @@ -6067,7 +5747,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -6086,7 +5765,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -6105,7 +5783,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -6124,7 +5801,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 4>, typename std::conditional<1, @@ -6143,7 +5819,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 4>, typename std::conditional<1, @@ -6163,7 +5838,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 8, - 0, kernel_config<256, 4>, typename std::conditional<1, WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, @@ -6181,7 +5855,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -6200,7 +5873,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 16>, typename std::conditional<1, @@ -6219,7 +5891,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -6238,7 +5909,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 16>, typename std::conditional<1, @@ -6257,7 +5927,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -6277,7 +5946,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 8, - 0, kernel_config<256, 8>, typename std::conditional<1, WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, @@ -6295,7 +5963,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -6314,7 +5981,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -6333,7 +5999,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -6352,7 +6017,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -6371,7 +6035,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -6391,7 +6054,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 8, - 0, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, @@ -6409,7 +6071,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 4>, typename std::conditional<1, @@ -6428,7 +6089,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -6447,7 +6107,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -6466,7 +6125,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -6485,7 +6143,6 @@ struct default_segmented_radix_sort_config< && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -6505,7 +6162,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 8, - 0, kernel_config<256, 16>, typename std::conditional<1, WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, @@ -6522,7 +6178,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 4>, typename std::conditional<1, @@ -6540,7 +6195,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -6558,7 +6212,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -6576,7 +6229,6 @@ struct default_segmented_radix_sort_config< std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -6595,7 +6247,6 @@ struct default_segmented_radix_sort_config< && (sizeof(value_type) <= 1) && (!std::is_same::value))>> : segmented_radix_sort_config< - 8, 8, kernel_config<256, 8>, typename std::conditional<1, @@ -6614,7 +6265,6 @@ struct default_segmented_radix_sort_config< && (std::is_same::value))>> : segmented_radix_sort_config< 8, - 0, kernel_config<256, 8>, typename std::conditional<1, WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, diff --git a/rocprim/include/rocprim/device/detail/config/device_select_flag.hpp b/rocprim/include/rocprim/device/detail/config/device_select_flag.hpp index 2c4e035e0..7ab24f63a 100644 --- a/rocprim/include/rocprim/device/detail/config/device_select_flag.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_select_flag.hpp @@ -22,7 +22,7 @@ #define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SELECT_FLAG_HPP_ #include "../../../config.hpp" -#include "../../../type_traits_interface.hpp" +#include "../../../type_traits.hpp" #include "../../config_types.hpp" #include "../device_config_helper.hpp" diff --git a/rocprim/include/rocprim/device/detail/config/device_select_predicate.hpp b/rocprim/include/rocprim/device/detail/config/device_select_predicate.hpp index dd69d9cb2..9efc5124b 100644 --- a/rocprim/include/rocprim/device/detail/config/device_select_predicate.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_select_predicate.hpp @@ -22,7 +22,7 @@ #define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SELECT_PREDICATE_HPP_ #include "../../../config.hpp" -#include "../../../type_traits_interface.hpp" +#include "../../../type_traits.hpp" #include "../../config_types.hpp" #include "../device_config_helper.hpp" diff --git a/rocprim/include/rocprim/device/detail/config/device_select_predicated_flag.hpp b/rocprim/include/rocprim/device/detail/config/device_select_predicated_flag.hpp index f0d004d7e..1b8faafb2 100644 --- a/rocprim/include/rocprim/device/detail/config/device_select_predicated_flag.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_select_predicated_flag.hpp @@ -22,7 +22,7 @@ #define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SELECT_PREDICATED_FLAG_HPP_ #include "../../../config.hpp" -#include "../../../type_traits_interface.hpp" +#include "../../../type_traits.hpp" #include "../../config_types.hpp" #include "../device_config_helper.hpp" diff --git a/rocprim/include/rocprim/device/detail/config/device_select_unique.hpp b/rocprim/include/rocprim/device/detail/config/device_select_unique.hpp index 670ae5da6..6088fe77e 100644 --- a/rocprim/include/rocprim/device/detail/config/device_select_unique.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_select_unique.hpp @@ -22,7 +22,7 @@ #define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SELECT_UNIQUE_HPP_ #include "../../../config.hpp" -#include "../../../type_traits_interface.hpp" +#include "../../../type_traits.hpp" #include "../../config_types.hpp" #include "../device_config_helper.hpp" diff --git a/rocprim/include/rocprim/device/detail/config/device_select_unique_by_key.hpp b/rocprim/include/rocprim/device/detail/config/device_select_unique_by_key.hpp index b0515386f..958e0c25f 100644 --- a/rocprim/include/rocprim/device/detail/config/device_select_unique_by_key.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_select_unique_by_key.hpp @@ -22,7 +22,7 @@ #define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SELECT_UNIQUE_BY_KEY_HPP_ #include "../../../config.hpp" -#include "../../../type_traits_interface.hpp" +#include "../../../type_traits.hpp" #include "../../config_types.hpp" #include "../device_config_helper.hpp" diff --git a/rocprim/include/rocprim/device/detail/config/device_transform.hpp b/rocprim/include/rocprim/device/detail/config/device_transform.hpp index 7ceda9d24..7e2f7f865 100644 --- a/rocprim/include/rocprim/device/detail/config/device_transform.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_transform.hpp @@ -22,7 +22,7 @@ #define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_TRANSFORM_HPP_ #include "../../../config.hpp" -#include "../../../type_traits_interface.hpp" +#include "../../../type_traits.hpp" #include "../../config_types.hpp" #include "../device_config_helper.hpp" @@ -656,7 +656,7 @@ struct default_transform_config< value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<1024, 4> + : transform_config<512, 4> {}; // Based on value_type = float @@ -675,7 +675,7 @@ struct default_transform_config< static_cast(target_arch::gfx942), value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : transform_config<128, 16> + && (sizeof(value_type) <= 2))>> : transform_config<256, 8> {}; // Based on value_type = rocprim::int128_t @@ -724,7 +724,7 @@ struct default_transform_config< static_cast(target_arch::gfx942), value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : transform_config<256, 16> + && (sizeof(value_type) <= 1))>> : transform_config<1024, 8> {}; } // end namespace detail diff --git a/rocprim/include/rocprim/device/detail/config/device_transform_pointer.hpp b/rocprim/include/rocprim/device/detail/config/device_transform_pointer.hpp new file mode 100644 index 000000000..91ee63d60 --- /dev/null +++ b/rocprim/include/rocprim/device/detail/config/device_transform_pointer.hpp @@ -0,0 +1,615 @@ +// Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_TRANSFORM_POINTER_HPP_ +#define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_TRANSFORM_POINTER_HPP_ + +#include "../../../config.hpp" +#include "../../../type_traits.hpp" +#include "../../config_types.hpp" +#include "../device_config_helper.hpp" + +#include + +/* DO NOT EDIT THIS FILE + * This file is automatically generated by `/scripts/autotune/create_optimization.py`. + * so most likely you want to edit rocprim/device/device_(algo)_config.hpp + */ + +/// \addtogroup primitivesmodule_deviceconfigs +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template +struct default_transform_pointer_config : default_transform_pointer_config_base::type +{}; + +// Based on value_type = double +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx1030), + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> +{}; + +// Based on value_type = float +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx1030), + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> +{}; + +// Based on value_type = rocprim::half +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx1030), + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))>> + : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> +{}; + +// Based on value_type = rocprim::int128_t +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx1030), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> + : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> +{}; + +// Based on value_type = int64_t +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx1030), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> +{}; + +// Based on value_type = int +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx1030), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> +{}; + +// Based on value_type = short +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx1030), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> +{}; + +// Based on value_type = int8_t +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx1030), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))>> + : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> +{}; + +// Based on value_type = double +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx1100), + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : transform_pointer_config<512, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> +{}; + +// Based on value_type = float +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx1100), + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> +{}; + +// Based on value_type = rocprim::half +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx1100), + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))>> + : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> +{}; + +// Based on value_type = rocprim::int128_t +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx1100), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> + : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> +{}; + +// Based on value_type = int64_t +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx1100), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : transform_pointer_config<512, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> +{}; + +// Based on value_type = int +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx1100), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> +{}; + +// Based on value_type = short +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx1100), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> +{}; + +// Based on value_type = int8_t +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx1100), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))>> + : transform_pointer_config<1024, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> +{}; + +// Based on value_type = double +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx906), + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> +{}; + +// Based on value_type = float +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx906), + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : transform_pointer_config<1024, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> +{}; + +// Based on value_type = rocprim::half +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx906), + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))>> + : transform_pointer_config<512, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> +{}; + +// Based on value_type = rocprim::int128_t +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx906), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> + : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> +{}; + +// Based on value_type = int64_t +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx906), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : transform_pointer_config<512, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> +{}; + +// Based on value_type = int +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx906), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : transform_pointer_config<1024, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> +{}; + +// Based on value_type = short +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx906), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : transform_pointer_config<1024, 8, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> +{}; + +// Based on value_type = int8_t +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx906), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))>> + : transform_pointer_config<512, 8, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> +{}; + +// Based on value_type = double +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx908), + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : transform_pointer_config<128, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> +{}; + +// Based on value_type = float +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx908), + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : transform_pointer_config<128, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> +{}; + +// Based on value_type = rocprim::half +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx908), + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))>> + : transform_pointer_config<128, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> +{}; + +// Based on value_type = rocprim::int128_t +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx908), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> + : transform_pointer_config<128, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> +{}; + +// Based on value_type = int64_t +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx908), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : transform_pointer_config<128, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> +{}; + +// Based on value_type = int +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx908), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : transform_pointer_config<128, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> +{}; + +// Based on value_type = short +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx908), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : transform_pointer_config<128, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> +{}; + +// Based on value_type = int8_t +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx908), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))>> + : transform_pointer_config<128, 8, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> +{}; + +// Based on value_type = double +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx90a), + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> +{}; + +// Based on value_type = float +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx90a), + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : transform_pointer_config<1024, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> +{}; + +// Based on value_type = rocprim::half +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx90a), + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))>> + : transform_pointer_config<1024, 8, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> +{}; + +// Based on value_type = rocprim::int128_t +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx90a), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> + : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> +{}; + +// Based on value_type = int64_t +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx90a), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> +{}; + +// Based on value_type = int +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx90a), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : transform_pointer_config<1024, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> +{}; + +// Based on value_type = short +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx90a), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : transform_pointer_config<1024, 8, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> +{}; + +// Based on value_type = int8_t +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx90a), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))>> + : transform_pointer_config<1024, 16, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> +{}; + +// Based on value_type = double +template +struct default_transform_pointer_config< + static_cast(target_arch::unknown), + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : transform_pointer_config<128, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> +{}; + +// Based on value_type = float +template +struct default_transform_pointer_config< + static_cast(target_arch::unknown), + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : transform_pointer_config<128, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> +{}; + +// Based on value_type = rocprim::half +template +struct default_transform_pointer_config< + static_cast(target_arch::unknown), + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))>> + : transform_pointer_config<128, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> +{}; + +// Based on value_type = rocprim::int128_t +template +struct default_transform_pointer_config< + static_cast(target_arch::unknown), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> + : transform_pointer_config<128, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> +{}; + +// Based on value_type = int64_t +template +struct default_transform_pointer_config< + static_cast(target_arch::unknown), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : transform_pointer_config<128, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> +{}; + +// Based on value_type = int +template +struct default_transform_pointer_config< + static_cast(target_arch::unknown), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : transform_pointer_config<128, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> +{}; + +// Based on value_type = short +template +struct default_transform_pointer_config< + static_cast(target_arch::unknown), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : transform_pointer_config<128, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> +{}; + +// Based on value_type = int8_t +template +struct default_transform_pointer_config< + static_cast(target_arch::unknown), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))>> + : transform_pointer_config<128, 8, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> +{}; + +// Based on value_type = double +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx942), + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> +{}; + +// Based on value_type = float +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx942), + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : transform_pointer_config<256, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> +{}; + +// Based on value_type = rocprim::half +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx942), + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))>> + : transform_pointer_config<256, 8, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> +{}; + +// Based on value_type = rocprim::int128_t +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx942), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> + : transform_pointer_config<256, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> +{}; + +// Based on value_type = int64_t +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx942), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : transform_pointer_config<256, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> +{}; + +// Based on value_type = int +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx942), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : transform_pointer_config<256, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> +{}; + +// Based on value_type = short +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx942), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : transform_pointer_config<256, 8, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> +{}; + +// Based on value_type = int8_t +template +struct default_transform_pointer_config< + static_cast(target_arch::gfx942), + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))>> + : transform_pointer_config<256, 16, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> +{}; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group primitivesmodule_deviceconfigs + +#endif // ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_TRANSFORM_POINTER_HPP_ diff --git a/rocprim/include/rocprim/device/detail/config/device_upper_bound.hpp b/rocprim/include/rocprim/device/detail/config/device_upper_bound.hpp index ae97db6ec..2cc38c24f 100644 --- a/rocprim/include/rocprim/device/detail/config/device_upper_bound.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_upper_bound.hpp @@ -22,7 +22,7 @@ #define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_UPPER_BOUND_HPP_ #include "../../../config.hpp" -#include "../../../type_traits_interface.hpp" +#include "../../../type_traits.hpp" #include "../../config_types.hpp" #include "../device_config_helper.hpp" diff --git a/rocprim/include/rocprim/device/detail/device_adjacent_difference.hpp b/rocprim/include/rocprim/device/detail/device_adjacent_difference.hpp index de69dfe21..21f982226 100644 --- a/rocprim/include/rocprim/device/detail/device_adjacent_difference.hpp +++ b/rocprim/include/rocprim/device/detail/device_adjacent_difference.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -180,7 +180,7 @@ ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void adjacent_difference_kernel_impl( const std::size_t starting_block) { using input_type = typename std::iterator_traits::value_type; - using output_type = rocprim::invoke_result_binary_op_t; + using output_type = ::rocprim::accumulator_t; static constexpr adjacent_difference_config_params params = device_params(); diff --git a/rocprim/include/rocprim/device/detail/device_config_helper.hpp b/rocprim/include/rocprim/device/detail/device_config_helper.hpp index 7be7bf4bb..36d7e2042 100644 --- a/rocprim/include/rocprim/device/detail/device_config_helper.hpp +++ b/rocprim/include/rocprim/device/detail/device_config_helper.hpp @@ -445,6 +445,7 @@ struct transform_config_tag struct transform_config_params { kernel_config_params kernel_config{}; + cache_load_modifier load_type; }; } // namespace detail @@ -479,11 +480,8 @@ struct segmented_radix_sort_config_params { /// \brief Kernel start parameters. kernel_config_params kernel_config{}; - /// \brief Number of bits in long iterations. - unsigned int long_radix_bits = 0; - /// \brief Number of bits in short iterations. - /// \deprecated The short radix bits parameter is no longer used and will be removed in a future version. - unsigned int short_radix_bits = 0; + /// \brief Number of bits in iterations. + unsigned int radix_bits = 0; /// \brief If set to \p true, warp sort can be used to sort the small segments, even if no partitioning happens. bool enable_unpartitioned_warp_sort = true; /// \brief Warp sort config params @@ -569,22 +567,19 @@ struct DisabledWarpSortConfig //// \brief Configuration of device-level segmented radix sort operation. /// /// Radix sort is excecuted in a few iterations (passes) depending on total number of bits to be sorted -/// (`begin_bit` and `end_bit`), each iteration sorts either `LongRadixBits` or `ShortRadixBits` bits +/// (`begin_bit` and `end_bit`), each iteration sorts `RadixBits` bits /// chosen to cover whole bit range in optimal way. /// -/// For example, if `LongRadixBits` is 7, `ShortRadixBits` is 6, `begin_bit` is 0 and `end_bit` is 32 -/// there will be 5 iterations: 7 + 7 + 6 + 6 + 6 = 32 bits. +/// For example, if `RadixBits` is 7, `begin_bit` is 0 and `end_bit` is 32 +/// there will be 5 iterations: 7 + 7 + 7 + 7 + 4 (still sorting with 7 bits) = 32 bits. /// /// If a segment's element count is low ( <= warp_sort_config::items_per_thread * warp_sort_config::logical_warp_size ), /// it is sorted by a special warp-level sorting method. /// -/// \tparam LongRadixBits number of bits in long iterations. -/// \tparam ShortRadixBits number of bits in short iterations, must be equal to or less than `LongRadixBits`. -/// Deprecated and no longer used. +/// \tparam RadixBits number of bits in long iterations. /// \tparam SortConfig configuration of radix sort kernel. Must be `kernel_config`. /// \tparam WarpSortConfig configuration of the warp sort that is used on the short segments. -template @@ -594,12 +589,8 @@ struct segmented_radix_sort_config : public detail::segmented_radix_sort_config_ using tag = detail::segmented_radix_sort_config_tag; #ifndef DOXYGEN_SHOULD_SKIP_THIS - /// \brief Number of bits in long iterations. - static constexpr unsigned int long_radix_bits = LongRadixBits; - - /// \brief Number of bits in short iterations. - /// \deprecated The short radix bits parameter is no longer used and will be removed in a future version. - static constexpr unsigned int short_radix_bits = ShortRadixBits; + /// \brief Number of bits in iterations. + static constexpr unsigned int radix_bits = RadixBits; /// \brief Number of threads in a block. static constexpr unsigned int block_size = SortConfig::block_size; @@ -618,8 +609,7 @@ struct segmented_radix_sort_config : public detail::segmented_radix_sort_config_ constexpr segmented_radix_sort_config() : detail::segmented_radix_sort_config_params{ SortConfig(), - LongRadixBits, - ShortRadixBits, + RadixBits, EnableUnpartitionedWarpSort, {warp_sort_config::partitioning_allowed, warp_sort_config::logical_warp_size_small, @@ -638,16 +628,14 @@ namespace detail { /// \brief Default segmented_radix_sort kernel configurations, such that the maximum shared memory is not exceeded. /// -/// \tparam LongRadixBits Long bits used during the sorting. -/// \tparam ShortRadixBits Short bits used during the sorting. +/// \tparam RadixBits Bits used during the sorting. /// \tparam ItemsPerThread Items per thread when type Key has size 1. -template +template struct default_segmented_radix_sort_config_base { static constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div( sizeof(unsigned int) + sizeof(unsigned int), sizeof(int)); - using type = segmented_radix_sort_config, WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, true>; @@ -659,9 +647,11 @@ struct default_segmented_radix_sort_config_base /// \tparam BlockSize Number of threads in a block. /// \tparam ItemsPerThread Number of items processed by each thread. /// \tparam SizeLimit Limit on the number of items for a single kernel launch. -template +/// \tparam LoadType The type of thread_load used. +template struct transform_config : public detail::transform_config_params { /// \brief Identifies the algorithm associated to the config. @@ -674,12 +664,51 @@ struct transform_config : public detail::transform_config_params /// \brief Number of items processed by each thread. static constexpr unsigned int items_per_thread = ItemsPerThread; + /// \brief The default load is being used. + static constexpr cache_load_modifier load_type = LoadType; + /// \brief Limit on the number of items for a single kernel launch. static constexpr unsigned int size_limit = SizeLimit; constexpr transform_config() : detail::transform_config_params{ - {BlockSize, ItemsPerThread, SizeLimit} + {BlockSize, ItemsPerThread, SizeLimit}, + LoadType + } + {} +#endif +}; + +/// \brief Configuration for the device-level transform operation for pointers. +/// \tparam BlockSize Number of threads in a block. +/// \tparam ItemsPerThread Number of items processed by each thread. +/// \tparam SizeLimit Limit on the number of items for a single kernel launch. +template +struct transform_pointer_config : public detail::transform_config_params +{ + /// \brief Identifies the algorithm associated to the config. + using tag = detail::transform_config_tag; +#ifndef DOXYGEN_SHOULD_SKIP_THIS + + /// \brief Number of threads in a block. + static constexpr unsigned int block_size = BlockSize; + + /// \brief Number of items processed by each thread. + static constexpr unsigned int items_per_thread = ItemsPerThread; + + /// \brief The type of thread_load being used. + static constexpr cache_load_modifier load_type = LoadType; + + /// \brief Limit on the number of items for a single kernel launch. + static constexpr unsigned int size_limit = SizeLimit; + + constexpr transform_pointer_config() + : detail::transform_config_params{ + {BlockSize, ItemsPerThread, SizeLimit}, + LoadType } {} #endif @@ -688,6 +717,15 @@ struct transform_config : public detail::transform_config_params namespace detail { +template +struct default_transform_pointer_config_base +{ + static constexpr unsigned int item_scale + = ::rocprim::detail::ceiling_div(sizeof(uint128_t), sizeof(Value)); + + using type = transform_config<256, item_scale>; +}; + template struct default_transform_config_base { @@ -713,7 +751,7 @@ struct lower_bound_config_tag : public transform_config_tag template -struct binary_search_config : transform_config +struct binary_search_config : transform_config { /// \brief Identifies the algorithm associated to the config. using tag = detail::binary_search_config_tag; @@ -726,7 +764,7 @@ struct binary_search_config : transform_config -struct upper_bound_config : transform_config +struct upper_bound_config : transform_config { /// \brief Identifies the algorithm associated to the config. using tag = detail::upper_bound_config_tag; @@ -739,7 +777,7 @@ struct upper_bound_config : transform_config -struct lower_bound_config : transform_config +struct lower_bound_config : transform_config { /// \brief Identifies the algorithm associated to the config. using tag = detail::lower_bound_config_tag; @@ -1402,8 +1440,8 @@ namespace detail { struct search_n_config_params { - size_t threshold; kernel_config_params kernel_config; + size_t threshold; }; } // namespace detail @@ -1411,18 +1449,35 @@ struct search_n_config_params /// /// \tparam BlockSize number of threads in a block. /// \tparam ItemsPerThread number of items processed by each thread. -template +template struct search_n_config : public detail::search_n_config_params { #ifndef DOXYGEN_DOCUMENTATION_BUILD constexpr search_n_config() : detail::search_n_config_params{ - 8, {BlockSize, ItemsPerThread, 0} + {BlockSize, ItemsPerThread, ROCPRIM_GRID_SIZE_LIMIT}, + Threshold } {} #endif }; +namespace detail +{ +template +struct default_search_n_config_base +{ + static constexpr unsigned int item_scale + = ::rocprim::detail::ceiling_div(sizeof(InputType), sizeof(int)); + + using type + = search_n_config::value, + ::rocprim::max(1u, 10u / item_scale), + 8>; +}; + +} // namespace detail + namespace detail { @@ -1465,9 +1520,9 @@ struct default_merge_config_base static constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div( ::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int)); - using type = merge_config::value, + using type = merge_config::value, ::rocprim::max(1u, 10u / item_scale)>; }; @@ -1477,12 +1532,12 @@ struct default_merge_config_base static constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(Key), sizeof(int)); - using type - = select_type>, - select_type_case>, - select_type_case>, - merge_config::value, - ::rocprim::max(1u, 10u / item_scale)>>; + using type = select_type< + select_type_case>, + select_type_case>, + select_type_case>, + merge_config::value, + ::rocprim::max(1u, 10u / item_scale)>>; }; } // namespace detail diff --git a/rocprim/include/rocprim/device/detail/device_merge.hpp b/rocprim/include/rocprim/device/detail/device_merge.hpp index 54abb8b41..f87467f10 100644 --- a/rocprim/include/rocprim/device/detail/device_merge.hpp +++ b/rocprim/include/rocprim/device/detail/device_merge.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -21,14 +21,17 @@ #ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_MERGE_HPP_ #define ROCPRIM_DEVICE_DETAIL_DEVICE_MERGE_HPP_ -#include #include +#include #include "../../config.hpp" #include "../../detail/various.hpp" +#include "../config_types.hpp" +#include "../device_merge_config.hpp" +#include "device_config_helper.hpp" -#include "../../intrinsics.hpp" #include "../../functional.hpp" +#include "../../intrinsics.hpp" #include "../../types.hpp" #include "../../block/block_store.hpp" @@ -53,20 +56,18 @@ range_t<> compute_range(const unsigned int id, return range_t<>{p1, p2, diag1 - p1, diag2 - p2}; } -template< - class IndexIterator, - class KeysInputIterator1, - class KeysInputIterator2, - class BinaryFunction -> +template ROCPRIM_DEVICE ROCPRIM_INLINE -void partition_kernel_impl(IndexIterator indices, +void partition_kernel_impl(IndexIterator indices, KeysInputIterator1 keys_input1, KeysInputIterator2 keys_input2, - const size_t input1_size, - const size_t input2_size, + const size_t input1_size, + const size_t input2_size, const unsigned int spacing, - BinaryFunction compare_function) + BinaryFunction compare_function) { const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); const unsigned int flat_block_id = ::rocprim::detail::block_id<0>(); @@ -89,20 +90,18 @@ void partition_kernel_impl(IndexIterator indices, indices[id] = begin; } -template< - unsigned int BlockSize, - unsigned int ItemsPerThread, - class KeysInputIterator1, - class KeysInputIterator2, - class KeyType -> +template ROCPRIM_DEVICE ROCPRIM_INLINE -void load(unsigned int flat_id, +void load(unsigned int flat_id, KeysInputIterator1 keys_input1, KeysInputIterator2 keys_input2, - KeyType * keys_shared, - const size_t input1_size, - const size_t input2_size) + KeyType* keys_shared, + const size_t input1_size, + const size_t input2_size) { ROCPRIM_UNROLL for(unsigned int i = 0; i < ItemsPerThread; ++i) @@ -114,7 +113,7 @@ void load(unsigned int flat_id, } else if(index < input1_size + input2_size) { - keys_shared[index] = keys_input2[index - input1_size]; + keys_shared[index] = static_cast(keys_input2[index - input1_size]); } } @@ -137,23 +136,22 @@ void merge_keys(unsigned int flat_id, range_t<> range, BinaryFunction compare_function) { - load( - flat_id, keys_input1 + range.begin1, keys_input2 + range.begin2, - keys_shared, range.count1(), range.count2() - ); + load(flat_id, + keys_input1 + range.begin1, + keys_input2 + range.begin2, + keys_shared, + range.count1(), + range.count2()); range_t<> range_local{0, range.count1(), range.count1(), (range.count1() + range.count2())}; - unsigned int diag = ItemsPerThread * flat_id; - unsigned int partition = - merge_path( - keys_shared + range_local.begin1, - keys_shared + range_local.begin2, - range_local.count1(), - range_local.count2(), - diag, - compare_function - ); + unsigned int diag = ItemsPerThread * flat_id; + unsigned int partition = merge_path(keys_shared + range_local.begin1, + keys_shared + range_local.begin2, + range_local.count1(), + range_local.count2(), + diag, + compare_function); range_t<> range_partition{range_local.begin1 + partition, range_local.end1, @@ -163,23 +161,20 @@ void merge_keys(unsigned int flat_id, serial_merge(keys_shared, key_inputs, index, range_partition, compare_function); } -template< - bool WithValues, - unsigned int BlockSize, - class ValuesInputIterator1, - class ValuesInputIterator2, - class ValuesOutputIterator, - unsigned int ItemsPerThread -> +template ROCPRIM_DEVICE ROCPRIM_INLINE -typename std::enable_if::type -merge_values(unsigned int flat_id, - ValuesInputIterator1 values_input1, - ValuesInputIterator2 values_input2, - ValuesOutputIterator values_output, - unsigned int (&index)[ItemsPerThread], - const size_t input1_size, - const size_t input2_size) +typename std::enable_if::type merge_values(unsigned int flat_id, + ValuesInputIterator1 values_input1, + ValuesInputIterator2 values_input2, + ValuesOutputIterator values_output, + unsigned int (&index)[ItemsPerThread], + const size_t input1_size, + const size_t input2_size) { using value_type = typename std::iterator_traits::value_type; @@ -192,8 +187,9 @@ merge_values(unsigned int flat_id, ROCPRIM_UNROLL for(unsigned int i = 0; i < ItemsPerThread; ++i) { - values[i] = (index[i] < input1_size) ? values_input1[index[i]] : - values_input2[index[i] - input1_size]; + values[i] = (index[i] < input1_size) + ? values_input1[index[i]] + : static_cast(values_input2[index[i] - input1_size]); } } else @@ -203,52 +199,69 @@ merge_values(unsigned int flat_id, { if(flat_id * ItemsPerThread + i < count) { - values[i] = (index[i] < input1_size) ? values_input1[index[i]] : - values_input2[index[i] - input1_size]; + values[i] = (index[i] < input1_size) + ? values_input1[index[i]] + : static_cast(values_input2[index[i] - input1_size]); } } } ::rocprim::syncthreads(); - block_store_direct_blocked( - flat_id, - values_output, - values, - count - ); + block_store_direct_blocked(flat_id, values_output, values, count); } -template< - bool WithValues, - unsigned int BlockSize, - class ValuesInputIterator1, - class ValuesInputIterator2, - class ValuesOutputIterator, - unsigned int ItemsPerThread -> +template ROCPRIM_DEVICE ROCPRIM_INLINE -typename std::enable_if::type -merge_values(unsigned int flat_id, - ValuesInputIterator1 values_input1, - ValuesInputIterator2 values_input2, - ValuesOutputIterator values_output, - unsigned int (&index)[ItemsPerThread], - const size_t input1_size, - const size_t input2_size) +typename std::enable_if::type merge_values(unsigned int flat_id, + ValuesInputIterator1 values_input1, + ValuesInputIterator2 values_input2, + ValuesOutputIterator values_output, + unsigned int (&index)[ItemsPerThread], + const size_t input1_size, + const size_t input2_size) { - (void) flat_id; - (void) values_input1; - (void) values_input2; - (void) values_output; - (void) index; - (void) input1_size; - (void) input2_size; + (void)flat_id; + (void)values_input1; + (void)values_input2; + (void)values_output; + (void)index; + (void)input1_size; + (void)input2_size; } -template< - unsigned int BlockSize, - unsigned int ItemsPerThread, +template +struct merge_kernel_impl_ +{ + static constexpr merge_config_params params = device_params(); + + static constexpr unsigned int block_size = params.kernel_config.block_size; + static constexpr unsigned int items_per_thread = params.kernel_config.items_per_thread; + static constexpr unsigned int items_per_block = block_size * items_per_thread; + static constexpr unsigned int input_block_size = block_size * items_per_thread + 1; + static constexpr bool with_values = !std::is_same::value; + + // Block primitives + using keys_store_type + = ::rocprim::block_store; + + union storage_type + { + ROCPRIM_DETAIL_SUPPRESS_DEPRECATION_WITH_PUSH + typename detail::raw_storage keys_shared; + ROCPRIM_DETAIL_SUPPRESS_DEPRECATION_POP + typename keys_store_type::storage_type keys_store; + }; + + template< class IndexIterator, class KeysInputIterator1, class KeysInputIterator2, @@ -258,8 +271,8 @@ template< class ValuesOutputIterator, class BinaryFunction > -ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE -void merge_kernel_impl(IndexIterator indices, + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE +void merge(IndexIterator indices, KeysInputIterator1 keys_input1, KeysInputIterator2 keys_input2, KeysOutputIterator keys_output, @@ -268,79 +281,75 @@ void merge_kernel_impl(IndexIterator indices, ValuesOutputIterator values_output, const size_t input1_size, const size_t input2_size, - BinaryFunction compare_function) -{ - using key_type = typename std::iterator_traits::value_type; - using value_type = typename std::iterator_traits::value_type; - using keys_store_type = ::rocprim::block_store< - key_type, BlockSize, ItemsPerThread, - ::rocprim::block_store_method::block_store_transpose - >; - constexpr bool with_values = !std::is_same::value; + BinaryFunction compare_function, + storage_type& storage) + { + using key_type1 = typename std::iterator_traits::value_type; + using key_type2 = typename std::iterator_traits::value_type; + using value_type1 = typename std::iterator_traits::value_type; + using value_type2 = typename std::iterator_traits::value_type; - constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; - constexpr unsigned int input_block_size = BlockSize * ItemsPerThread + 1; + if constexpr(with_values) + { + static_assert(std::is_convertible_v, + "values_input2 must be convertible to values_input1"); + } - ROCPRIM_SHARED_MEMORY union - { - ROCPRIM_DETAIL_SUPPRESS_DEPRECATION_WITH_PUSH - typename detail::raw_storage keys_shared; - ROCPRIM_DETAIL_SUPPRESS_DEPRECATION_POP - typename keys_store_type::storage_type keys_store; - } storage; + static_assert(std::is_convertible_v, + "Keys_input2 must be convertible to keys_input1"); - key_type input[ItemsPerThread]; - unsigned int index[ItemsPerThread]; + Key input[items_per_thread]; + unsigned int index[items_per_thread]; - const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); - const unsigned int flat_block_id = ::rocprim::detail::block_id<0>(); - const unsigned int block_offset = flat_block_id * items_per_block; - const unsigned int count = input1_size + input2_size; - const unsigned int valid_in_last_block = count - block_offset; - const bool is_incomplete_block = valid_in_last_block < items_per_block; + const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); + const unsigned int flat_block_id = ::rocprim::detail::block_id<0>(); + const unsigned int block_offset = flat_block_id * items_per_block; + const unsigned int count = input1_size + input2_size; + const unsigned int valid_in_last_block = count - block_offset; + const bool is_incomplete_block = valid_in_last_block < items_per_block; - const unsigned int partitions = (count + items_per_block - 1) / items_per_block; + const unsigned int partitions = (count + items_per_block - 1) / items_per_block; - const unsigned int p1 = indices[rocprim::min(flat_block_id, partitions)]; - const unsigned int p2 = indices[rocprim::min(flat_block_id + 1, partitions)]; + const unsigned int p1 = indices[rocprim::min(flat_block_id, partitions)]; + const unsigned int p2 = indices[rocprim::min(flat_block_id + 1, partitions)]; - range_t<> range - = compute_range(flat_block_id, input1_size, input2_size, items_per_block, p1, p2); + range_t<> range + = compute_range(flat_block_id, input1_size, input2_size, items_per_block, p1, p2); - merge_keys( - flat_id, keys_input1, keys_input2, input, index, - storage.keys_shared.get(), - range, compare_function - ); + merge_keys(flat_id, + keys_input1, + keys_input2, + input, + index, + storage.keys_shared.get(), + range, + compare_function); - ::rocprim::syncthreads(); + ::rocprim::syncthreads(); - if(is_incomplete_block) // # elements in last block may not equal items_per_block for the last block - { - keys_store_type().store( - keys_output + block_offset, - input, - valid_in_last_block, - storage.keys_store - ); - } - else - { - keys_store_type().store( - keys_output + block_offset, - input, - storage.keys_store - ); - } + if(is_incomplete_block) // # elements in last block may not equal items_per_block for the last block + { + keys_store_type().store(keys_output + block_offset, + input, + valid_in_last_block, + storage.keys_store); + } + else + { + keys_store_type().store(keys_output + block_offset, input, storage.keys_store); + } - merge_values( - flat_id, values_input1 + range.begin1, values_input2 + range.begin2, - values_output + block_offset, index, - range.count1(), range.count2() - ); -} + merge_values(flat_id, + values_input1 + range.begin1, + values_input2 + range.begin2, + values_output + block_offset, + index, + range.count1(), + range.count2()); + } +}; -} // end of detail namespace +} // namespace detail END_ROCPRIM_NAMESPACE diff --git a/rocprim/include/rocprim/device/detail/device_merge_sort.hpp b/rocprim/include/rocprim/device/detail/device_merge_sort.hpp index bce843af4..6dd152abe 100644 --- a/rocprim/include/rocprim/device/detail/device_merge_sort.hpp +++ b/rocprim/include/rocprim/device/detail/device_merge_sort.hpp @@ -21,14 +21,14 @@ #ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_MERGE_SORT_HPP_ #define ROCPRIM_DEVICE_DETAIL_DEVICE_MERGE_SORT_HPP_ -#include #include +#include #include "../../config.hpp" #include "../../detail/various.hpp" -#include "../../intrinsics.hpp" #include "../../functional.hpp" +#include "../../intrinsics.hpp" #include "../../types.hpp" #include "../../block/block_load.hpp" @@ -41,73 +41,66 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template< - bool WithValues, - unsigned int BlockSize, - unsigned int ItemsPerThread, - class Key, - class Value -> -struct block_store_impl { +template +struct block_store_impl +{ using block_store_type = block_store; using storage_type = typename block_store_type::storage_type; template - ROCPRIM_DEVICE ROCPRIM_INLINE void store(const OffsetT block_offset, - const unsigned int valid_in_last_block, - const bool is_incomplete_block, - KeysOutputIterator keys_output, - ValuesOutputIterator /*values_output*/, - Key (&keys)[ItemsPerThread], - Value (&/*values*/)[ItemsPerThread], - storage_type& storage) + ROCPRIM_DEVICE ROCPRIM_INLINE + void store(const OffsetT block_offset, + const unsigned int valid_in_last_block, + const bool is_incomplete_block, + KeysOutputIterator keys_output, + ValuesOutputIterator /*values_output*/, + Key (&keys)[ItemsPerThread], + Value (& /*values*/)[ItemsPerThread], + storage_type& storage) { // Synchronize before reusing shared memory ::rocprim::syncthreads(); if(is_incomplete_block) { - block_store_type().store( - keys_output + block_offset, - keys, - valid_in_last_block, - storage - ); + block_store_type().store(keys_output + block_offset, + keys, + valid_in_last_block, + storage); } else { - block_store_type().store( - keys_output + block_offset, - keys, - storage - ); + block_store_type().store(keys_output + block_offset, keys, storage); } } }; -template< - unsigned int BlockSize, - unsigned int ItemsPerThread, - class Key, - class Value -> -struct block_store_impl { - using block_store_key_type = block_store; - using block_store_value_type = block_store; - - union storage_type { +template +struct block_store_impl +{ + using block_store_key_type + = block_store; + using block_store_value_type + = block_store; + + union storage_type + { typename block_store_key_type::storage_type keys; typename block_store_value_type::storage_type values; }; - template + template ROCPRIM_DEVICE ROCPRIM_INLINE - void store(const OffsetT block_offset, - const unsigned int valid_in_last_block, - const bool is_incomplete_block, - KeysOutputIterator keys_output, + void store(const OffsetT block_offset, + const unsigned int valid_in_last_block, + const bool is_incomplete_block, + KeysOutputIterator keys_output, ValuesOutputIterator values_output, Key (&keys)[ItemsPerThread], Value (&values)[ItemsPerThread], @@ -118,37 +111,25 @@ struct block_store_impl { if(is_incomplete_block) { - block_store_key_type().store( - keys_output + block_offset, - keys, - valid_in_last_block, - storage.keys - ); + block_store_key_type().store(keys_output + block_offset, + keys, + valid_in_last_block, + storage.keys); ::rocprim::syncthreads(); - block_store_value_type().store( - values_output + block_offset, - values, - valid_in_last_block, - storage.values - ); + block_store_value_type().store(values_output + block_offset, + values, + valid_in_last_block, + storage.values); } else { - block_store_key_type().store( - keys_output + block_offset, - keys, - storage.keys - ); + block_store_key_type().store(keys_output + block_offset, keys, storage.keys); ::rocprim::syncthreads(); - block_store_value_type().store( - values_output + block_offset, - values, - storage.values - ); + block_store_value_type().store(values_output + block_offset, values, storage.values); } } }; @@ -171,10 +152,11 @@ struct block_permute_values_impl }; template - ROCPRIM_DEVICE void permute(unsigned int (&ranks)[ItemsPerThread], - ValuesInputIterator values_input, - ValuesOutputIterator values_output, - storage_type& storage) + ROCPRIM_DEVICE + void permute(unsigned int (&ranks)[ItemsPerThread], + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + storage_type& storage) { syncthreads(); const auto flat_id = block_thread_id<0>(); @@ -186,11 +168,12 @@ struct block_permute_values_impl } template - ROCPRIM_DEVICE void permute(unsigned int (&ranks)[ItemsPerThread], - ValuesInputIterator values_input, - ValuesOutputIterator values_output, - const unsigned int valid_in_last_block, - storage_type& storage) + ROCPRIM_DEVICE + void permute(unsigned int (&ranks)[ItemsPerThread], + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + const unsigned int valid_in_last_block, + storage_type& storage) { syncthreads(); const auto flat_id = block_thread_id<0>(); @@ -208,10 +191,11 @@ struct block_permute_values_impl using storage_type = empty_storage_type; template - ROCPRIM_DEVICE void permute(unsigned int (&ranks)[ItemsPerThread], - ValuesInputIterator values_input, - ValuesOutputIterator values_output, - storage_type& storage) + ROCPRIM_DEVICE + void permute(unsigned int (&ranks)[ItemsPerThread], + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + storage_type& storage) { (void)ranks; (void)values_input; @@ -220,11 +204,12 @@ struct block_permute_values_impl } template - ROCPRIM_DEVICE void permute(unsigned int (&ranks)[ItemsPerThread], - ValuesInputIterator values_input, - ValuesOutputIterator values_output, - const unsigned int valid_in_last_block, - storage_type& storage) + ROCPRIM_DEVICE + void permute(unsigned int (&ranks)[ItemsPerThread], + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + const unsigned int valid_in_last_block, + storage_type& storage) { (void)ranks; (void)values_input; @@ -259,10 +244,11 @@ struct block_permute_values_impl - ROCPRIM_DEVICE void permute(unsigned int (&ranks)[ItemsPerThread], - ValuesInputIterator values_input, - ValuesOutputIterator values_output, - storage_type& storage_) + ROCPRIM_DEVICE + void permute(unsigned int (&ranks)[ItemsPerThread], + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + storage_type& storage_) { syncthreads(); auto& values_shared = storage_.get().values; @@ -285,11 +271,12 @@ struct block_permute_values_impl - ROCPRIM_DEVICE void permute(unsigned int (&ranks)[ItemsPerThread], - ValuesInputIterator values_input, - ValuesOutputIterator values_output, - const unsigned int valid_in_last_block, - storage_type& storage_) + ROCPRIM_DEVICE + void permute(unsigned int (&ranks)[ItemsPerThread], + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + const unsigned int valid_in_last_block, + storage_type& storage_) { syncthreads(); auto& values_shared = storage_.get().values; @@ -352,14 +339,15 @@ struct block_sort_impl typename ValuesInputIterator, typename ValuesOutputIterator, typename BinaryFunction> - ROCPRIM_DEVICE void sort(unsigned int valid_in_last_block, - const bool is_incomplete_block, - KeysInputIterator keys_input, - KeysOutputIterator keys_output, - ValuesInputIterator /*values_input*/, - ValuesOutputIterator /*values_output*/, - BinaryFunction compare_function, - storage_type& storage) + ROCPRIM_DEVICE + void sort(unsigned int valid_in_last_block, + const bool is_incomplete_block, + KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator /*values_input*/, + ValuesOutputIterator /*values_output*/, + BinaryFunction compare_function, + storage_type& storage) { Key keys[ItemsPerThread]; @@ -422,14 +410,15 @@ struct block_sort_impl - ROCPRIM_DEVICE void sort(const unsigned int valid_in_last_block, - const bool is_incomplete_block, - KeysInputIterator keys_input, - KeysOutputIterator keys_output, - ValuesInputIterator values_input, - ValuesOutputIterator values_output, - BinaryFunction compare_function, - storage_type& storage) + ROCPRIM_DEVICE + void sort(const unsigned int valid_in_last_block, + const bool is_incomplete_block, + KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + BinaryFunction compare_function, + storage_type& storage) { Key keys[ItemsPerThread]; Value values[ItemsPerThread]; @@ -497,14 +486,15 @@ struct block_sort_impl - ROCPRIM_DEVICE void sort(const unsigned int valid_in_last_block, - const bool is_incomplete_block, - KeysInputIterator keys_input, - KeysOutputIterator keys_output, - ValuesInputIterator values_input, - ValuesOutputIterator values_output, - BinaryFunction compare_function, - storage_type& storage) + ROCPRIM_DEVICE + void sort(const unsigned int valid_in_last_block, + const bool is_incomplete_block, + KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + BinaryFunction compare_function, + storage_type& storage) { Key keys[ItemsPerThread]; @@ -589,7 +579,7 @@ void block_merge_oddeven_kernel(KeysInputIterator keys_input, } block_load_direct_blocked(flat_id, keys_input + block_offset, keys, valid_in_last_block); - if ROCPRIM_IF_CONSTEXPR(with_values) + if constexpr(with_values) { block_load_direct_blocked(flat_id, values_input + block_offset, @@ -600,7 +590,7 @@ void block_merge_oddeven_kernel(KeysInputIterator keys_input, else { block_load_direct_blocked(flat_id, keys_input + block_offset, keys); - if ROCPRIM_IF_CONSTEXPR(with_values) + if constexpr(with_values) { block_load_direct_blocked(flat_id, values_input + block_offset, values); } @@ -630,7 +620,7 @@ void block_merge_oddeven_kernel(KeysInputIterator keys_input, if(id < input_size) { keys_output[id] = keys[i]; - if ROCPRIM_IF_CONSTEXPR(with_values) + if constexpr(with_values) { values_output[id] = values[i]; } @@ -644,7 +634,7 @@ void block_merge_oddeven_kernel(KeysInputIterator keys_input, { const OffsetT id = block_offset + thread_offset + i; keys_output[id] = keys[i]; - if ROCPRIM_IF_CONSTEXPR(with_values) + if constexpr(with_values) { values_output[id] = values[i]; } @@ -675,7 +665,7 @@ void block_merge_oddeven_kernel(KeysInputIterator keys_input, OffsetT offset = dest_offset + i + left_id; // Destination offset (target calculation) keys_output[offset] = keys[i]; - if ROCPRIM_IF_CONSTEXPR(with_values) + if constexpr(with_values) { values_output[offset] = values[i]; } @@ -702,7 +692,7 @@ void block_merge_oddeven_kernel(KeysInputIterator keys_input, } } -} // end of detail namespace +} // namespace detail END_ROCPRIM_NAMESPACE diff --git a/rocprim/include/rocprim/device/detail/device_merge_sort_mergepath.hpp b/rocprim/include/rocprim/device/detail/device_merge_sort_mergepath.hpp index 1778c2f13..67557323e 100644 --- a/rocprim/include/rocprim/device/detail/device_merge_sort_mergepath.hpp +++ b/rocprim/include/rocprim/device/detail/device_merge_sort_mergepath.hpp @@ -33,93 +33,91 @@ #include "../../detail/various.hpp" -#include "device_merge_sort.hpp" #include "device_merge.hpp" +#include "device_merge_sort.hpp" BEGIN_ROCPRIM_NAMESPACE namespace detail { - // Load items from input1 and input2 from global memory - template +// Load items from input1 and input2 from global memory +template ROCPRIM_DEVICE ROCPRIM_INLINE - void gmem_to_reg(KeyT (&output)[ItemsPerThread], - InputIterator input1, - InputIterator input2, - unsigned int count1, - unsigned int count2, - bool IsLastTile) +void gmem_to_reg(KeyT (&output)[ItemsPerThread], + InputIterator input1, + InputIterator input2, + unsigned int count1, + unsigned int count2, + bool IsLastTile) +{ + if(IsLastTile) { - if(IsLastTile) - { - ROCPRIM_UNROLL - for (unsigned int item = 0; item < ItemsPerThread; ++item) - { - unsigned int idx = rocprim::flat_block_size() * item + threadIdx.x; - if (idx < count1 + count2) - { - output[item] = (idx < count1) ? input1[idx] : input2[idx - count1]; - } - } - - } - else + ROCPRIM_UNROLL + for(unsigned int item = 0; item < ItemsPerThread; ++item) { - ROCPRIM_UNROLL - for (unsigned int item = 0; item < ItemsPerThread; ++item) + unsigned int idx = rocprim::flat_block_size() * item + threadIdx.x; + if(idx < count1 + count2) { - unsigned int idx = rocprim::flat_block_size() * item + threadIdx.x; output[item] = (idx < count1) ? input1[idx] : input2[idx - count1]; } } } - - template - ROCPRIM_DEVICE ROCPRIM_INLINE - void reg_to_shared(OutputIterator output, - KeyT (&input)[ItemsPerThread]) + else { ROCPRIM_UNROLL - for (unsigned int item = 0; item < ItemsPerThread; ++item) + for(unsigned int item = 0; item < ItemsPerThread; ++item) { - unsigned int idx = BlockSize * item + threadIdx.x; - output[idx] = input[item]; + unsigned int idx = rocprim::flat_block_size() * item + threadIdx.x; + output[item] = (idx < count1) ? input1[idx] : input2[idx - count1]; } } +} - template - struct block_merge_impl; - - template - struct block_merge_impl::value - || rocprim::is_floating_point::value - || std::is_integral::value>> +template + ROCPRIM_DEVICE ROCPRIM_INLINE +void reg_to_shared(OutputIterator output, KeyT (&input)[ItemsPerThread]) +{ + ROCPRIM_UNROLL + for(unsigned int item = 0; item < ItemsPerThread; ++item) { + unsigned int idx = BlockSize * item + threadIdx.x; + output[idx] = input[item]; + } +} + +template +struct block_merge_impl; + +template +struct block_merge_impl< + Key, + Value, + BlockSize, + ItemsPerThread, + std::enable_if_t::value + || rocprim::is_floating_point::value || std::is_integral::value>> +{ - static constexpr bool with_values = !std::is_same::value; - static constexpr unsigned int items_per_tile = BlockSize * ItemsPerThread; + static constexpr bool with_values = !std::is_same::value; + static constexpr unsigned int items_per_tile = BlockSize * ItemsPerThread; - using block_store = block_store_impl; + using block_store = block_store_impl; - using keys_storage_ = Key[items_per_tile + 1]; - using values_storage_ = Value[items_per_tile + 1]; + using keys_storage_ = Key[items_per_tile + 1]; + using values_storage_ = Value[items_per_tile + 1]; - union storage_type - { - typename block_store::storage_type store; - ROCPRIM_DETAIL_SUPPRESS_DEPRECATION_WITH_PUSH - detail::raw_storage keys; - detail::raw_storage values; - ROCPRIM_DETAIL_SUPPRESS_DEPRECATION_POP - }; + union storage_type + { + typename block_store::storage_type store; + ROCPRIM_DETAIL_SUPPRESS_DEPRECATION_WITH_PUSH + detail::raw_storage keys; + detail::raw_storage values; + ROCPRIM_DETAIL_SUPPRESS_DEPRECATION_POP + }; template - ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void process_tile(KeysInputIterator keys_input, + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void process_tile(KeysInputIterator keys_input, KeysOutputIterator keys_output, ValuesInputIterator values_input, ValuesOutputIterator values_output, @@ -194,7 +192,7 @@ namespace detail reg_to_shared(keys_shared, keys); Value values[ItemsPerThread]; - if ROCPRIM_IF_CONSTEXPR(with_values) + if constexpr(with_values) { gmem_to_reg(values, values_input + keys1_beg, @@ -228,7 +226,7 @@ namespace detail serial_merge(keys_shared, keys, indices, range_local, compare_function); rocprim::syncthreads(); - if ROCPRIM_IF_CONSTEXPR(with_values) + if constexpr(with_values) { reg_to_shared(values_shared, values); @@ -253,38 +251,38 @@ namespace detail values, storage.store); } - }; - - // The specialization below exists because the compiler creates slow code for - // ValueTypes with misaligned datastructures in them (e.g. custom_char_double) - // when storing/loading those ValueTypes to/from registers. - // Thus this is a temporary workaround. - template - struct block_merge_impl::value - && !rocprim::is_floating_point::value - && !std::is_integral::value>> - { +}; + +// The specialization below exists because the compiler creates slow code for +// ValueTypes with misaligned datastructures in them (e.g. custom_char_double) +// when storing/loading those ValueTypes to/from registers. +// Thus this is a temporary workaround. +template +struct block_merge_impl::value + && !rocprim::is_floating_point::value + && !std::is_integral::value>> +{ - static constexpr bool with_values = !std::is_same::value; - static constexpr unsigned int items_per_tile = BlockSize * ItemsPerThread; + static constexpr bool with_values = !std::is_same::value; + static constexpr unsigned int items_per_tile = BlockSize * ItemsPerThread; - using block_store = block_store_impl; + using block_store = block_store_impl; - using keys_storage_ = Key[items_per_tile + 1]; - using values_storage_ = Value[items_per_tile + 1]; + using keys_storage_ = Key[items_per_tile + 1]; + using values_storage_ = Value[items_per_tile + 1]; - union storage_type - { - typename block_store::storage_type store; - ROCPRIM_DETAIL_SUPPRESS_DEPRECATION_WITH_PUSH - detail::raw_storage keys; - detail::raw_storage values; - ROCPRIM_DETAIL_SUPPRESS_DEPRECATION_POP - }; + union storage_type + { + typename block_store::storage_type store; + ROCPRIM_DETAIL_SUPPRESS_DEPRECATION_WITH_PUSH + detail::raw_storage keys; + detail::raw_storage values; + ROCPRIM_DETAIL_SUPPRESS_DEPRECATION_POP + }; template - ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void process_tile(KeysInputIterator keys_input, + ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void process_tile(KeysInputIterator keys_input, KeysOutputIterator keys_output, ValuesInputIterator values_input, ValuesOutputIterator values_output, @@ -382,7 +380,7 @@ namespace detail serial_merge(keys_shared, keys, indices, range_local, compare_function); rocprim::syncthreads(); - if ROCPRIM_IF_CONSTEXPR(with_values) + if constexpr(with_values) { const ValuesInputIterator input1 = values_input + keys1_beg; const ValuesInputIterator input2 = values_input + keys2_beg; @@ -456,9 +454,9 @@ namespace detail values, storage.store); } - }; +}; -} // end of detail namespace +} // namespace detail END_ROCPRIM_NAMESPACE diff --git a/rocprim/include/rocprim/device/detail/device_radix_sort.hpp b/rocprim/include/rocprim/device/detail/device_radix_sort.hpp index 5c1cec1ea..ca35ad7d6 100644 --- a/rocprim/include/rocprim/device/detail/device_radix_sort.hpp +++ b/rocprim/include/rocprim/device/detail/device_radix_sort.hpp @@ -21,8 +21,8 @@ #ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_RADIX_SORT_HPP_ #define ROCPRIM_DEVICE_DETAIL_DEVICE_RADIX_SORT_HPP_ -#include #include +#include #include "../../config.hpp" #include "../../detail/various.hpp" @@ -39,7 +39,6 @@ #include "../../block/block_radix_sort.hpp" #include "../../block/block_scan.hpp" #include "../../block/block_store_func.hpp" -#include "../../thread/radix_key_codec.hpp" BEGIN_ROCPRIM_NAMESPACE @@ -63,7 +62,7 @@ void sort_warp_striped_to_striped(SortType sorter, unsigned int begin_bit, unsigned int end_bit) { - if ROCPRIM_IF_CONSTEXPR(Descending) + if constexpr(Descending) { sorter.sort_desc_warp_striped_to_striped(keys, values, @@ -92,8 +91,8 @@ void sort_warp_striped_to_striped(SortType sorter, unsigned int begin_bit, unsigned int end_bit) { - (void) values; - if ROCPRIM_IF_CONSTEXPR(Descending) + (void)values; + if constexpr(Descending) { sorter.sort_desc_warp_striped_to_striped(keys, storage, begin_bit, end_bit, decomposer); } @@ -103,17 +102,15 @@ void sort_warp_striped_to_striped(SortType sorter, } } -template< - unsigned int WarpSize, - unsigned int BlockSize, - unsigned int ItemsPerThread, - unsigned int RadixBits, - bool Descending -> +template struct radix_digit_count_helper { static constexpr unsigned int radix_size = 1 << RadixBits; - static constexpr unsigned int warp_size = WarpSize; + static constexpr unsigned int warp_size = WarpSize; static constexpr unsigned int atomic_stripes = 4; static constexpr unsigned int counters = radix_size * atomic_stripes; @@ -138,25 +135,21 @@ struct radix_digit_count_helper return digit * atomic_stripes + stripe; } - template< - bool IsFull = false, - class KeysInputIterator, - class Offset - > + template ROCPRIM_DEVICE ROCPRIM_INLINE void count_digits(KeysInputIterator keys_input, - Offset begin_offset, - Offset end_offset, - unsigned int bit, - unsigned int current_radix_bits, - storage_type& storage, - unsigned int& digit_count) // i-th thread will get i-th digit's value + Offset begin_offset, + Offset end_offset, + unsigned int bit, + unsigned int current_radix_bits, + storage_type& storage, + unsigned int& digit_count) // i-th thread will get i-th digit's value { constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; using key_type = typename std::iterator_traits::value_type; - - using key_codec = ::rocprim::radix_key_codec; + using key_codec + = decltype(::rocprim::traits::get().template radix_key_codec()); using bit_key_type = typename key_codec::bit_key_type; const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); @@ -176,9 +169,10 @@ struct radix_digit_count_helper ::rocprim::syncthreads(); - for(Offset block_offset = begin_offset; block_offset < end_offset; block_offset += items_per_block) + for(Offset block_offset = begin_offset; block_offset < end_offset; + block_offset += items_per_block) { - key_type keys[ItemsPerThread]; + key_type keys[ItemsPerThread]; unsigned int valid_count; // Use loading into a striped arrangement because an order of items is irrelevant, // only totals matter @@ -190,14 +184,18 @@ struct radix_digit_count_helper else { valid_count = end_offset - block_offset; - block_load_direct_striped(flat_id, keys_input + block_offset, keys, valid_count); + block_load_direct_striped(flat_id, + keys_input + block_offset, + keys, + valid_count); } ROCPRIM_UNROLL for(unsigned int i = 0; i < ItemsPerThread; i++) { const bit_key_type bit_key = key_codec::encode(keys[i]); - const unsigned int digit = key_codec::extract_digit(bit_key, bit, current_radix_bits); + const unsigned int digit + = key_codec::extract_digit(bit_key, bit, current_radix_bits); const unsigned int pos = i * BlockSize + flat_id; if(IsFull || pos < valid_count) @@ -222,18 +220,16 @@ struct radix_digit_count_helper } }; -template< - unsigned int BlockSize, - unsigned int ItemsPerThread, - bool Descending, - class Key, - class Value -> +template struct radix_sort_single_helper { static constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; - using key_type = Key; + using key_type = Key; using value_type = Value; using sort_type = ::rocprim::block_radix_sort; @@ -250,32 +246,33 @@ struct radix_sort_single_helper class ValuesInputIterator, class ValuesOutputIterator, class Decomposer> - ROCPRIM_DEVICE ROCPRIM_INLINE void sort_single(KeysInputIterator keys_input, - KeysOutputIterator keys_output, - ValuesInputIterator values_input, - ValuesOutputIterator values_output, - unsigned int size, - Decomposer decomposer, - unsigned int bit, - unsigned int current_radix_bits, - storage_type& storage) + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort_single(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + unsigned int size, + Decomposer decomposer, + unsigned int bit, + unsigned int current_radix_bits, + storage_type& storage) { - const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); - const unsigned int flat_block_id = ::rocprim::detail::block_id<0>(); + const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); + const unsigned int flat_block_id = ::rocprim::detail::block_id<0>(); const unsigned int block_offset = flat_block_id * items_per_block; const bool is_incomplete_block = flat_block_id == (size / items_per_block); const unsigned int valid_in_last_block = size - block_offset; using key_type = typename std::iterator_traits::value_type; + using key_codec + = decltype(::rocprim::traits::get().template radix_key_codec()); - using key_codec = radix_key_codec; - - key_type keys[ItemsPerThread]; + key_type keys[ItemsPerThread]; value_type values[ItemsPerThread]; if(!is_incomplete_block) { block_load_direct_warp_striped(flat_id, keys_input + block_offset, keys); - if ROCPRIM_IF_CONSTEXPR(with_values) + if constexpr(with_values) { block_load_direct_warp_striped(flat_id, values_input + block_offset, values); } @@ -288,7 +285,7 @@ struct radix_sort_single_helper keys, valid_in_last_block, out_of_bounds); - if ROCPRIM_IF_CONSTEXPR(with_values) + if constexpr(with_values) { block_load_direct_warp_striped(flat_id, values_input + block_offset, @@ -309,7 +306,7 @@ struct radix_sort_single_helper if(!is_incomplete_block) { block_store_direct_striped(flat_id, keys_output + block_offset, keys); - if ROCPRIM_IF_CONSTEXPR(with_values) + if constexpr(with_values) { block_store_direct_striped(flat_id, values_output + block_offset, @@ -322,7 +319,7 @@ struct radix_sort_single_helper keys_output + block_offset, keys, valid_in_last_block); - if ROCPRIM_IF_CONSTEXPR(with_values) + if constexpr(with_values) { block_store_direct_striped(flat_id, values_output + block_offset, @@ -343,12 +340,13 @@ template struct radix_sort_and_scatter_helper { - static constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; - static constexpr unsigned int radix_size = 1 << RadixBits; + static constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; + static constexpr unsigned int radix_size = 1 << RadixBits; static constexpr unsigned int digits_per_thread = 1; static constexpr bool with_values = !std::is_same::value; - using key_codec = radix_key_codec; + using key_codec + = decltype(::rocprim::traits::get().template radix_key_codec()); using radix_rank_type = ::rocprim::block_radix_rank; static constexpr bool load_warp_striped @@ -378,9 +376,9 @@ struct radix_sort_and_scatter_helper class ValuesInputIterator, class ValuesOutputIterator> ROCPRIM_DEVICE ROCPRIM_INLINE - void sort_and_scatter(KeysInputIterator keys_input, - KeysOutputIterator keys_output, - ValuesInputIterator values_input, + void sort_and_scatter(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, ValuesOutputIterator values_output, Offset begin_offset, Offset end_offset, @@ -397,7 +395,8 @@ struct radix_sort_and_scatter_helper storage.digit_offsets[flat_id] = digit_start; } - for(Offset block_offset = begin_offset; block_offset < end_offset; block_offset += items_per_block) + for(Offset block_offset = begin_offset; block_offset < end_offset; + block_offset += items_per_block) { Key keys[ItemsPerThread]; @@ -405,7 +404,7 @@ struct radix_sort_and_scatter_helper if(IsFull || (block_offset + items_per_block <= end_offset)) { valid_items = items_per_block; - if ROCPRIM_IF_CONSTEXPR(load_warp_striped) + if constexpr(load_warp_striped) { block_load_direct_warp_striped(flat_id, keys_input + block_offset, keys); } @@ -424,7 +423,7 @@ struct radix_sort_and_scatter_helper // it does not matter. It does cause the final digit offset to be increased past its end, // but again this does not matter since this is the last iteration in which it will be used anyway. const Key out_of_bounds = key_codec::get_out_of_bounds_key(); - if ROCPRIM_IF_CONSTEXPR(load_warp_striped) + if constexpr(load_warp_striped) { block_load_direct_warp_striped(flat_id, keys_input + block_offset, @@ -495,12 +494,12 @@ struct radix_sort_and_scatter_helper } // Gather and scatter values if necessary - if ROCPRIM_IF_CONSTEXPR(with_values) + if constexpr(with_values) { Value values[ItemsPerThread]; - if ROCPRIM_IF_CONSTEXPR(IsFull) + if constexpr(IsFull) { - if ROCPRIM_IF_CONSTEXPR(load_warp_striped) + if constexpr(load_warp_striped) { block_load_direct_warp_striped(flat_id, values_input + block_offset, @@ -513,7 +512,7 @@ struct radix_sort_and_scatter_helper } else { - if ROCPRIM_IF_CONSTEXPR(load_warp_striped) + if constexpr(load_warp_striped) { block_load_direct_warp_striped(flat_id, values_input + block_offset, @@ -596,13 +595,11 @@ ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void sort_single(KeysInputIterator keys_i unsigned int bit, unsigned int current_radix_bits) { - using key_type = typename std::iterator_traits::value_type; + using key_type = typename std::iterator_traits::value_type; using value_type = typename std::iterator_traits::value_type; - using sort_single_helper = radix_sort_single_helper< - BlockSize, ItemsPerThread, Descending, - key_type, value_type - >; + using sort_single_helper + = radix_sort_single_helper; ROCPRIM_SHARED_MEMORY typename sort_single_helper::storage_type storage; @@ -619,8 +616,8 @@ ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void sort_single(KeysInputIterator keys_i template ROCPRIM_DEVICE ROCPRIM_INLINE -auto compare_nan_sensitive(const T& a, const T& b) - -> typename std::enable_if::value, bool>::type +auto compare_nan_sensitive(const T& a, const T& b) -> + typename std::enable_if::value, bool>::type { // Beware: the performance of this function is extremely vulnerable to refactoring. // Always check benchmark_device_segmented_radix_sort and benchmark_device_radix_sort @@ -644,7 +641,8 @@ auto compare_nan_sensitive(const T& a, const T& b) } template -ROCPRIM_DEVICE auto compare_nan_sensitive(const T& a, const T& b) -> +ROCPRIM_DEVICE +auto compare_nan_sensitive(const T& a, const T& b) -> typename std::enable_if::value, bool>::type { return a > b; @@ -656,7 +654,8 @@ struct radix_merge_compare; template struct radix_merge_compare { - ROCPRIM_DEVICE bool operator()(const T& a, const T& b) const + ROCPRIM_DEVICE + bool operator()(const T& a, const T& b) const { return compare_nan_sensitive(b, a); } @@ -665,7 +664,8 @@ struct radix_merge_compare template struct radix_merge_compare { - ROCPRIM_DEVICE bool operator()(const T& a, const T& b) const + ROCPRIM_DEVICE + bool operator()(const T& a, const T& b) const { return compare_nan_sensitive(a, b); } @@ -682,13 +682,14 @@ struct radix_merge_compare { T radix_mask_upper = (T(1) << (current_radix_bits + start_bit)) - 1; T radix_mask_bottom = (T(1) << start_bit) - 1; - radix_mask = radix_mask_upper ^ radix_mask_bottom; + radix_mask = radix_mask_upper ^ radix_mask_bottom; } - ROCPRIM_DEVICE bool operator()(const T& a, const T& b) const + ROCPRIM_DEVICE + bool operator()(const T& a, const T& b) const { - const T masked_key_a = a & radix_mask; - const T masked_key_b = b & radix_mask; + const T masked_key_a = a & radix_mask; + const T masked_key_b = b & radix_mask; return Descending ? masked_key_a > masked_key_b : masked_key_b > masked_key_a; } }; @@ -706,9 +707,11 @@ struct radix_merge_compare : decomposer_(decomposer), start_bit_(start_bit), radix_bits_(current_radix_bits) {} - ROCPRIM_HOST_DEVICE bool operator()(T lhs, T rhs) const + ROCPRIM_HOST_DEVICE + bool operator()(T lhs, T rhs) const { - using codec_t = radix_key_codec; + using codec_t + = decltype(::rocprim::traits::get().template radix_key_codec()); // Encoding the values considers the ascending / descending nature of the sort codec_t::encode_inplace(lhs, decomposer_); @@ -773,7 +776,8 @@ struct onesweep_histograms_helper = radix_size * max_digit_places * atomic_stripes; using counter_type = uint32_t; - using key_codec = radix_key_codec; + using key_codec + = decltype(::rocprim::traits::get().template radix_key_codec()); struct storage_type { @@ -787,8 +791,8 @@ struct onesweep_histograms_helper return (place * radix_size + digit) * atomic_stripes + stripe_index; } - ROCPRIM_DEVICE ROCPRIM_INLINE void clear_histogram(const unsigned int flat_id, - storage_type& storage) + ROCPRIM_DEVICE ROCPRIM_INLINE + void clear_histogram(const unsigned int flat_id, storage_type& storage) { for(unsigned int i = flat_id; i < histogram_counters; i += BlockSize) { @@ -811,7 +815,7 @@ struct onesweep_histograms_helper KeyType keys[ItemsPerThread]; // Load using a striped arrangement, the order doesn't matter here. - if ROCPRIM_IF_CONSTEXPR(IsFull) + if constexpr(IsFull) { block_load_direct_striped(flat_id, keys_input, keys); } @@ -832,7 +836,7 @@ struct onesweep_histograms_helper key_codec::encode_inplace(keys[i], decomposer); } - if ROCPRIM_IF_CONSTEXPR(AllBits) + if constexpr(AllBits) { ROCPRIM_UNROLL for(unsigned int i = 0; i < ItemsPerThread; ++i) @@ -916,7 +920,7 @@ ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void onesweep_histograms(KeysInputIterator const unsigned int begin_bit, const unsigned int end_bit) { - using key_type = typename std::iterator_traits::value_type; + using key_type = typename std::iterator_traits::value_type; using count_helper_type = onesweep_histograms_helper(); + const Offset block_id = ::rocprim::detail::block_id<0>(); const Offset block_offset = block_id * ItemsPerThread * BlockSize; ROCPRIM_SHARED_MEMORY typename count_helper_type::storage_type storage; @@ -968,7 +972,8 @@ ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void onesweep_histograms(KeysInputIterator } template -ROCPRIM_DEVICE void onesweep_scan_histograms(Offset* global_digit_offsets) +ROCPRIM_DEVICE +void onesweep_scan_histograms(Offset* global_digit_offsets) { using block_scan_type = block_scan; @@ -1013,23 +1018,27 @@ struct onesweep_lookback_state : state(static_cast(status) | value) {} - ROCPRIM_DEVICE ROCPRIM_INLINE underlying_type value() const + ROCPRIM_DEVICE ROCPRIM_INLINE + underlying_type value() const { return this->state & value_mask; } - ROCPRIM_DEVICE ROCPRIM_INLINE prefix_flag status() const + ROCPRIM_DEVICE ROCPRIM_INLINE + prefix_flag status() const { return static_cast(this->state & status_mask); } - ROCPRIM_DEVICE ROCPRIM_INLINE static onesweep_lookback_state load(onesweep_lookback_state* ptr) + ROCPRIM_DEVICE ROCPRIM_INLINE + static onesweep_lookback_state load(onesweep_lookback_state* ptr) { underlying_type state = ::rocprim::detail::atomic_load(&ptr->state); return onesweep_lookback_state(state); } - ROCPRIM_DEVICE ROCPRIM_INLINE void store(onesweep_lookback_state* ptr) const + ROCPRIM_DEVICE ROCPRIM_INLINE + void store(onesweep_lookback_state* ptr) const { ::rocprim::detail::atomic_store(&ptr->state, this->state); } @@ -1050,7 +1059,8 @@ struct onesweep_iteration_helper static constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; static constexpr bool with_values = !std::is_same::value; - using key_codec = radix_key_codec; + using key_codec + = decltype(::rocprim::traits::get().template radix_key_codec()); using radix_rank_type = ::rocprim::block_radix_rank; static constexpr bool load_warp_striped @@ -1066,8 +1076,8 @@ struct onesweep_iteration_helper Offset global_digit_offsets[radix_size]; union { - Key ordered_block_keys[items_per_block]; - Value ordered_block_values[items_per_block]; + Key ordered_block_keys[items_per_block]; + Value ordered_block_values[items_per_block]; }; }; }; @@ -1081,18 +1091,19 @@ struct onesweep_iteration_helper class KeysOutputIterator, class ValuesInputIterator, class ValuesOutputIterator> - ROCPRIM_DEVICE void onesweep(KeysInputIterator keys_input, - KeysOutputIterator keys_output, - ValuesInputIterator values_input, - ValuesOutputIterator values_output, - Offset* global_digit_offsets_in, - Offset* global_digit_offsets_out, - onesweep_lookback_state* lookback_states, - Decomposer decomposer, - const unsigned int bit, - const unsigned int current_radix_bits, - const unsigned int valid_items, - storage_type_& storage) + ROCPRIM_DEVICE + void onesweep(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + ValuesOutputIterator values_output, + Offset* global_digit_offsets_in, + Offset* global_digit_offsets_out, + onesweep_lookback_state* lookback_states, + Decomposer decomposer, + const unsigned int bit, + const unsigned int current_radix_bits, + const unsigned int valid_items, + storage_type_& storage) { const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); const unsigned int block_id = ::rocprim::detail::block_id<0>(); @@ -1100,9 +1111,9 @@ struct onesweep_iteration_helper // Load keys into private memory, and encode them to unsigned integers. Key keys[ItemsPerThread]; - if ROCPRIM_IF_CONSTEXPR(IsFull) + if constexpr(IsFull) { - if ROCPRIM_IF_CONSTEXPR(load_warp_striped) + if constexpr(load_warp_striped) { block_load_direct_warp_striped(flat_id, keys_input + block_offset, keys); } @@ -1120,7 +1131,7 @@ struct onesweep_iteration_helper // it does not matter. It does cause the final digit offset to be increased past its end, // but again this does not matter since this is the last iteration in which it will be used anyway. const Key out_of_bounds = key_codec::get_out_of_bounds_key(decomposer); - if ROCPRIM_IF_CONSTEXPR(load_warp_striped) + if constexpr(load_warp_striped) { block_load_direct_warp_striped(flat_id, keys_input + block_offset, @@ -1241,9 +1252,9 @@ struct onesweep_iteration_helper if(with_values) { Value values[ItemsPerThread]; - if ROCPRIM_IF_CONSTEXPR(IsFull) + if constexpr(IsFull) { - if ROCPRIM_IF_CONSTEXPR(load_warp_striped) + if constexpr(load_warp_striped) { block_load_direct_warp_striped(flat_id, values_input + block_offset, values); } @@ -1254,7 +1265,7 @@ struct onesweep_iteration_helper } else { - if ROCPRIM_IF_CONSTEXPR(load_warp_striped) + if constexpr(load_warp_striped) { block_load_direct_warp_striped(flat_id, values_input + block_offset, @@ -1366,7 +1377,7 @@ ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void Decomposer>; constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; - const unsigned int block_id = ::rocprim::detail::block_id<0>(); + const unsigned int block_id = ::rocprim::detail::block_id<0>(); ROCPRIM_SHARED_MEMORY typename onesweep_iteration_helper_type::storage_type storage; diff --git a/rocprim/include/rocprim/device/detail/device_reduce_by_key.hpp b/rocprim/include/rocprim/device/detail/device_reduce_by_key.hpp index ed36cbe2e..e2e8ecd25 100644 --- a/rocprim/include/rocprim/device/detail/device_reduce_by_key.hpp +++ b/rocprim/include/rocprim/device/detail/device_reduce_by_key.hpp @@ -48,9 +48,8 @@ namespace reduce_by_key { template -using accumulator_type_t = - typename invoke_result_binary_op<::rocprim::detail::value_type_t, - BinaryOp>::type; +using accumulator_type_t + = ::rocprim::accumulator_t>; template using wrapped_type_t = rocprim::tuple; @@ -85,19 +84,20 @@ struct load_helper }; template - ROCPRIM_DEVICE void load_keys_values(KeyIterator tile_keys, - ValueIterator tile_values, - const bool is_global_last_tile, - const unsigned int valid_in_global_last_tile, - KeyType (&keys)[ItemsPerThread], - AccumulatorType (&values)[ItemsPerThread], - storage_type& storage) + ROCPRIM_DEVICE + void load_keys_values(KeyIterator tile_keys, + ValueIterator tile_values, + const bool is_global_last_tile, + const unsigned int valid_in_global_last_tile, + KeyType (&keys)[ItemsPerThread], + AccumulatorType (&values)[ItemsPerThread], + storage_type& storage) { if(!is_global_last_tile) { block_load_keys{}.load(tile_keys, keys, storage.keys); - if ROCPRIM_IF_CONSTEXPR(requires_inner_sync) + if constexpr(requires_inner_sync) { ::rocprim::syncthreads(); } @@ -106,7 +106,7 @@ struct load_helper else { block_load_keys{}.load(tile_keys, keys, valid_in_global_last_tile, storage.keys); - if ROCPRIM_IF_CONSTEXPR(requires_inner_sync) + if constexpr(requires_inner_sync) { ::rocprim::syncthreads(); } @@ -125,14 +125,15 @@ struct discontinuity_helper using storage_type = typename block_discontinuity_type::storage_type; template - ROCPRIM_DEVICE void flag_heads(KeyIterator tile_keys, - const KeyType (&keys)[ItemsPerThread], - CompareFunction compare, - unsigned int (&head_flags)[ItemsPerThread], - const bool is_global_first_tile, - const bool is_global_last_tile, - const size_t remaining, - storage_type& storage) + ROCPRIM_DEVICE + void flag_heads(KeyIterator tile_keys, + const KeyType (&keys)[ItemsPerThread], + CompareFunction compare, + unsigned int (&head_flags)[ItemsPerThread], + const bool is_global_first_tile, + const bool is_global_last_tile, + const size_t remaining, + storage_type& storage) { if(is_global_last_tile) { @@ -184,13 +185,14 @@ struct scatter_helper ROCPRIM_DETAIL_SUPPRESS_DEPRECATION_POP template - ROCPRIM_DEVICE void scatter(ValueIterator tile_values, - ValueFunction&& values, - const Flag (&is_selected)[ItemsPerThread], - IndexFunction&& block_indices, - const unsigned int selected_in_tile, - const unsigned int flat_thread_id, - storage_type& storage) + ROCPRIM_DEVICE + void scatter(ValueIterator tile_values, + ValueFunction&& values, + const Flag (&is_selected)[ItemsPerThread], + IndexFunction&& block_indices, + const unsigned int selected_in_tile, + const unsigned int flat_thread_id, + storage_type& storage) { // Check if all threads in the warp are selecting the same location (selected or rejected) uint8_t all_check = 3; // [true, true] diff --git a/rocprim/include/rocprim/device/detail/device_scan.hpp b/rocprim/include/rocprim/device/detail/device_scan.hpp index c80842a75..8a36a830b 100644 --- a/rocprim/include/rocprim/device/detail/device_scan.hpp +++ b/rocprim/include/rocprim/device/detail/device_scan.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -51,16 +51,17 @@ namespace detail // Helper functions for performing exclusive or inclusive // block scan in single_scan. template -ROCPRIM_DEVICE ROCPRIM_INLINE auto single_scan_block_scan(T (&input)[ItemsPerThread], - T (&output)[ItemsPerThread], - T initial_value, - typename BlockScan::storage_type& storage, - BinaryFunction scan_op) -> - typename std::enable_if::type +ROCPRIM_DEVICE ROCPRIM_INLINE +auto single_scan_block_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + T initial_value, + typename BlockScan::storage_type& storage, + BinaryFunction scan_op) -> typename std::enable_if::type { BlockScan().exclusive_scan(input, // input output, // output @@ -70,26 +71,38 @@ ROCPRIM_DEVICE ROCPRIM_INLINE auto single_scan_block_scan(T (&input)[ItemsPerThr } template -ROCPRIM_DEVICE ROCPRIM_INLINE auto single_scan_block_scan(T (&input)[ItemsPerThread], - T (&output)[ItemsPerThread], - T initial_value, - typename BlockScan::storage_type& storage, - BinaryFunction scan_op) -> - typename std::enable_if::type +ROCPRIM_DEVICE ROCPRIM_INLINE +auto single_scan_block_scan(T (&input)[ItemsPerThread], + T (&output)[ItemsPerThread], + T initial_value, + typename BlockScan::storage_type& storage, + BinaryFunction scan_op) -> typename std::enable_if::type { - (void)initial_value; - BlockScan().inclusive_scan(input, // input - output, // output - storage, - scan_op); + if constexpr(UseInitialValue) + { + BlockScan().inclusive_scan(input, // input + initial_value, + output, // output + storage, + scan_op); + } + else + { + BlockScan().inclusive_scan(input, // input + output, // output + storage, + scan_op); + } } template(*(input - 1))); + } else if(flat_block_thread_id == 0) + { values[0] = scan_op(previous_last_element[0], values[0]); + } } AccType reduction; - lookback_block_scan(values, // input/output - initial_value, - reduction, - storage.scan, - scan_op); - if(flat_block_thread_id == 0) + // Since `override_first_value` isn't a constexpr and there's no exclusive block scan + // overload without an initial_value parameter, the two scan types need separate + // code paths, this duplicates a bit of code. + if constexpr(Exclusive) { - scan_state.set_complete(flat_block_id, reduction); + lookback_block_scan(values, // input/output + initial_value, + reduction, + storage.scan, + scan_op); + + // Reduction should not contain initial value. + if(flat_block_thread_id == 0) + { + scan_state.set_complete(flat_block_id, reduction); + } + } + else + { + if(UseInitialValue && !override_first_value) + { + lookback_block_scan(values, // input/output + initial_value, + reduction, + storage.scan, + scan_op); + } + else + { + // Only use the initial value on the first iteration + lookback_block_scan(values, // input/output + reduction, + storage.scan, + scan_op); + } + + // Reduction should include initial value. We can avoid block-wide communication + // communication by letting the thread that has the last element of the scan + // write it to memory. + if(flat_block_thread_id == block_size - 1) + { + scan_state.set_complete(flat_block_id, values[items_per_thread - 1]); + } } } else diff --git a/rocprim/include/rocprim/device/detail/device_scan_by_key.hpp b/rocprim/include/rocprim/device/detail/device_scan_by_key.hpp index a04551d68..b0f933b21 100644 --- a/rocprim/include/rocprim/device/detail/device_scan_by_key.hpp +++ b/rocprim/include/rocprim/device/detail/device_scan_by_key.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -367,11 +367,21 @@ namespace detail } wrapped_type reduction; - lookback_block_scan(wrapped_values, - wrapped_initial_value, - reduction, - storage.scan, - wrapped_op); + if constexpr(Exclusive) + { + lookback_block_scan(wrapped_values, + wrapped_initial_value, + reduction, + storage.scan, + wrapped_op); + } + else + { + lookback_block_scan(wrapped_values, + reduction, + storage.scan, + wrapped_op); + } if(flat_thread_id == 0) { diff --git a/rocprim/include/rocprim/device/detail/device_scan_common.hpp b/rocprim/include/rocprim/device/detail/device_scan_common.hpp index 4207f986a..08e715b0a 100644 --- a/rocprim/include/rocprim/device/detail/device_scan_common.hpp +++ b/rocprim/include/rocprim/device/detail/device_scan_common.hpp @@ -139,24 +139,43 @@ ROCPRIM_KERNEL } #ifndef DOXYGEN_SHOULD_SKIP_THIS - template - ROCPRIM_DEVICE ROCPRIM_INLINE auto - lookback_block_scan(T (&values)[ItemsPerThread], - T /* initial_value */, - T& reduction, - typename BlockScan::storage_type& storage, - BinaryFunction scan_op) -> typename std::enable_if::type - { - BlockScan().inclusive_scan(values, // input - values, // output - reduction, - storage, - scan_op); - } +template + ROCPRIM_DEVICE ROCPRIM_INLINE +auto lookback_block_scan(T (&values)[ItemsPerThread], + T& reduction, + typename BlockScan::storage_type& storage, + BinaryFunction scan_op) -> typename std::enable_if::type +{ + BlockScan().inclusive_scan(values, // input + values, // output + reduction, + storage, + scan_op); +} + +template + ROCPRIM_DEVICE ROCPRIM_INLINE +auto lookback_block_scan(T (&values)[ItemsPerThread], + T initial_value, + T& reduction, + typename BlockScan::storage_type& storage, + BinaryFunction scan_op) -> typename std::enable_if::type +{ + BlockScan().inclusive_scan(values, // input + initial_value, + values, // output + reduction, + storage, + scan_op); +} template >>( @@ -419,7 +419,7 @@ hipError_t search_impl(void* temporary_storage, } else { - if ROCPRIM_IF_CONSTEXPR(find_first) + if constexpr(find_first) { start_timer(); search_kernels::search_kernel<<>>( @@ -445,7 +445,7 @@ hipError_t search_impl(void* temporary_storage, } } - if ROCPRIM_IF_CONSTEXPR(!find_first) + if constexpr(!find_first) { start_timer(); search_kernels::reverse_index_kernel<<<1, 1, 0, stream>>>(tmp_output, size, keys_size); diff --git a/rocprim/include/rocprim/device/detail/device_segmented_radix_sort.hpp b/rocprim/include/rocprim/device/detail/device_segmented_radix_sort.hpp index cc230cd07..2d5264f36 100644 --- a/rocprim/include/rocprim/device/detail/device_segmented_radix_sort.hpp +++ b/rocprim/include/rocprim/device/detail/device_segmented_radix_sort.hpp @@ -21,8 +21,8 @@ #ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_SEGMENTED_RADIX_SORT_HPP_ #define ROCPRIM_DEVICE_DETAIL_DEVICE_SEGMENTED_RADIX_SORT_HPP_ -#include #include +#include #include "../../config.hpp" #include "../../detail/various.hpp" @@ -33,15 +33,13 @@ #include "../../types.hpp" #include "../../block/block_load.hpp" -#include "../../block/block_store.hpp" #include "../../block/block_scan.hpp" +#include "../../block/block_store.hpp" #include "../../warp/detail/warp_sort_stable.hpp" #include "../../warp/warp_load.hpp" #include "../../warp/warp_store.hpp" -#include "../../thread/radix_key_codec.hpp" - #include "../device_segmented_radix_sort_config.hpp" #include "device_radix_sort.hpp" @@ -50,30 +48,32 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template< - class Key, - class Value, - unsigned int WarpSize, - unsigned int BlockSize, - unsigned int ItemsPerThread, - unsigned int RadixBits, - bool Descending -> +template class segmented_radix_sort_helper { static constexpr unsigned int radix_size = 1 << RadixBits; - using key_type = Key; + using key_type = Key; using value_type = Value; - using count_helper_type = radix_digit_count_helper; - using scan_type = typename ::rocprim::block_scan; - using sort_and_scatter_helper = radix_sort_and_scatter_helper< - BlockSize, ItemsPerThread, RadixBits, Descending, - key_type, value_type, unsigned int>; + using count_helper_type + = radix_digit_count_helper; + using scan_type = typename ::rocprim::block_scan; + using sort_and_scatter_helper = radix_sort_and_scatter_helper; public: - union storage_type { typename count_helper_type::storage_type count_helper; @@ -111,42 +111,54 @@ class segmented_radix_sort_helper { if(to_output) { - sort( - keys_input, keys_output, values_input, values_output, - begin_offset, end_offset, - bit, current_radix_bits, - storage - ); + sort(keys_input, + keys_output, + values_input, + values_output, + begin_offset, + end_offset, + bit, + current_radix_bits, + storage); } else { - sort( - keys_input, keys_tmp, values_input, values_tmp, - begin_offset, end_offset, - bit, current_radix_bits, - storage - ); + sort(keys_input, + keys_tmp, + values_input, + values_tmp, + begin_offset, + end_offset, + bit, + current_radix_bits, + storage); } } else { if(to_output) { - sort( - keys_tmp, keys_output, values_tmp, values_output, - begin_offset, end_offset, - bit, current_radix_bits, - storage - ); + sort(keys_tmp, + keys_output, + values_tmp, + values_output, + begin_offset, + end_offset, + bit, + current_radix_bits, + storage); } else { - sort( - keys_output, keys_tmp, values_output, values_tmp, - begin_offset, end_offset, - bit, current_radix_bits, - storage - ); + sort(keys_output, + keys_tmp, + values_output, + values_tmp, + begin_offset, + end_offset, + bit, + current_radix_bits, + storage); } } } @@ -173,24 +185,24 @@ class segmented_radix_sort_helper const bool is_first_iteration = (bit == begin_bit); - key_type * current_keys_input; - key_type * current_keys_output; - value_type * current_values_input; - value_type * current_values_output; + key_type* current_keys_input; + key_type* current_keys_output; + value_type* current_values_input; + value_type* current_values_output; if(is_first_iteration) { if(to_output) { - current_keys_input = keys_input; - current_keys_output = keys_output; - current_values_input = values_input; + current_keys_input = keys_input; + current_keys_output = keys_output; + current_values_input = values_input; current_values_output = values_output; } else { - current_keys_input = keys_input; - current_keys_output = keys_tmp; - current_values_input = values_input; + current_keys_input = keys_input; + current_keys_output = keys_tmp; + current_values_input = values_input; current_values_output = values_tmp; } } @@ -198,29 +210,31 @@ class segmented_radix_sort_helper { if(to_output) { - current_keys_input = keys_tmp; - current_keys_output = keys_output; - current_values_input = values_tmp; + current_keys_input = keys_tmp; + current_keys_output = keys_output; + current_values_input = values_tmp; current_values_output = values_output; } else { - current_keys_input = keys_output; - current_keys_output = keys_tmp; - current_values_input = values_output; + current_keys_input = keys_output; + current_keys_output = keys_tmp; + current_values_input = values_output; current_values_output = values_tmp; } } - sort( - current_keys_input, current_keys_output, current_values_input, current_values_output, - begin_offset, end_offset, - bit, current_radix_bits, - storage - ); + sort(current_keys_input, + current_keys_output, + current_values_input, + current_values_output, + begin_offset, + end_offset, + bit, + current_radix_bits, + storage); } private: - template< class KeysInputIterator, class KeysOutputIterator, @@ -239,13 +253,13 @@ class segmented_radix_sort_helper storage_type& storage) { unsigned int digit_count; - count_helper_type().count_digits( - keys_input, - begin_offset, end_offset, - bit, current_radix_bits, - storage.count_helper, - digit_count - ); + count_helper_type().count_digits(keys_input, + begin_offset, + end_offset, + bit, + current_radix_bits, + storage.count_helper, + digit_count); unsigned int digit_start; scan_type().exclusive_scan(digit_count, digit_start, 0); @@ -253,28 +267,30 @@ class segmented_radix_sort_helper ::rocprim::syncthreads(); - sort_and_scatter_helper().sort_and_scatter( - keys_input, keys_output, values_input, values_output, - begin_offset, end_offset, - bit, current_radix_bits, - digit_start, - storage.sort_and_scatter_helper - ); + sort_and_scatter_helper().sort_and_scatter(keys_input, + keys_output, + values_input, + values_output, + begin_offset, + end_offset, + bit, + current_radix_bits, + digit_start, + storage.sort_and_scatter_helper); ::rocprim::syncthreads(); } }; -template< - class Key, - class Value, - unsigned int BlockSize, - unsigned int ItemsPerThread, - bool Descending -> +template class segmented_radix_sort_single_block_helper { - using key_codec = radix_key_codec; + using key_codec + = decltype(::rocprim::traits::get().template radix_key_codec()); using bit_key_type = typename key_codec::bit_key_type; using sort_type = ::rocprim::block_radix_sort::value; public: - union storage_type { typename sort_type::storage_type sort; @@ -314,21 +329,27 @@ class segmented_radix_sort_single_block_helper { if(to_output) { - sort( - keys_input, keys_output, values_input, values_output, - begin_offset, end_offset, - begin_bit, end_bit, - storage - ); + sort(keys_input, + keys_output, + values_input, + values_output, + begin_offset, + end_offset, + begin_bit, + end_bit, + storage); } else { - sort( - keys_input, keys_tmp, values_input, values_tmp, - begin_offset, end_offset, - begin_bit, end_bit, - storage - ); + sort(keys_input, + keys_tmp, + values_input, + values_tmp, + begin_offset, + end_offset, + begin_bit, + end_bit, + storage); } } @@ -347,30 +368,31 @@ class segmented_radix_sort_single_block_helper unsigned int end_bit, storage_type& storage) { - sort( - keys_input, (to_output ? keys_output : keys_tmp), values_input, (to_output ? values_output : values_tmp), - begin_offset, end_offset, - begin_bit, end_bit, - storage - ); + sort(keys_input, + (to_output ? keys_output : keys_tmp), + values_input, + (to_output ? values_output : values_tmp), + begin_offset, + end_offset, + begin_bit, + end_bit, + storage); } - template< - class KeysInputIterator, - class KeysOutputIterator, - class ValuesInputIterator, - class ValuesOutputIterator - > + template ROCPRIM_DEVICE ROCPRIM_INLINE - bool sort(KeysInputIterator keys_input, - KeysOutputIterator keys_output, - ValuesInputIterator values_input, + bool sort(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, ValuesOutputIterator values_output, - unsigned int begin_offset, - unsigned int end_offset, - unsigned int begin_bit, - unsigned int end_bit, - storage_type& storage) + unsigned int begin_offset, + unsigned int end_offset, + unsigned int begin_bit, + unsigned int end_bit, + storage_type& storage) { constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; @@ -388,13 +410,16 @@ class segmented_radix_sort_single_block_helper } // Recursively check if it is possible to sort the segment using fewer items per thread - const bool processed_by_shorter = - shorter_single_block_helper().sort( - keys_input, keys_output, values_input, values_output, - begin_offset, end_offset, - begin_bit, end_bit, - reinterpret_cast(storage) - ); + const bool processed_by_shorter = shorter_single_block_helper().sort( + keys_input, + keys_output, + values_input, + values_output, + begin_offset, + end_offset, + begin_bit, + end_bit, + reinterpret_cast(storage)); if(processed_by_shorter) { return true; @@ -434,7 +459,7 @@ class segmented_radix_sort_single_block_helper keys_output + begin_offset, keys, valid_count); - if ROCPRIM_IF_CONSTEXPR(with_values) + if constexpr(with_values) { block_store_direct_striped(flat_id, values_output + begin_offset, @@ -446,24 +471,17 @@ class segmented_radix_sort_single_block_helper } }; -template< - class Key, - class Value, - unsigned int BlockSize, - bool Descending -> +template class segmented_radix_sort_single_block_helper { public: + struct storage_type + {}; - struct storage_type { }; - - template< - class KeysInputIterator, - class KeysOutputIterator, - class ValuesInputIterator, - class ValuesOutputIterator - > + template ROCPRIM_DEVICE ROCPRIM_INLINE bool sort(KeysInputIterator, KeysOutputIterator, @@ -509,13 +527,12 @@ template ROCPRIM_DEVICE ROCPRIM_INLINE void sort(Args&&...) - { - } + {} }; template @@ -528,9 +545,11 @@ class segmented_warp_sort_helper< std::enable_if_t::value>> { static constexpr unsigned int logical_warp_size = Config::logical_warp_size; - static constexpr unsigned int items_per_thread = Config::items_per_thread; + static constexpr unsigned int items_per_thread = Config::items_per_thread; + + using key_codec + = decltype(::rocprim::traits::get().template radix_key_codec()); - using key_codec = ::rocprim::radix_key_codec; using bit_key_type = typename key_codec::bit_key_type; using keys_load_type = ::rocprim::warp_load; @@ -550,11 +569,11 @@ class segmented_warp_sort_helper< union storage_type { - typename keys_load_type::storage_type keys_load; - typename values_load_type::storage_type values_load; - typename keys_store_type::storage_type keys_store; + typename keys_load_type::storage_type keys_load; + typename values_load_type::storage_type values_load; + typename keys_store_type::storage_type keys_store; typename values_store_type::storage_type values_store; - typename sort_type::storage_type sort; + typename sort_type::storage_type sort; }; private: @@ -614,34 +633,39 @@ class segmented_warp_sort_helper< } public: - template< - class KeysInputIterator, - class KeysOutputIterator, - class ValuesInputIterator, - class ValuesOutputIterator - > + template ROCPRIM_DEVICE ROCPRIM_INLINE - void sort(KeysInputIterator keys_input, - KeysOutputIterator keys_output, - ValuesInputIterator values_input, + void sort(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, ValuesOutputIterator values_output, - unsigned int begin_offset, - unsigned int end_offset, - unsigned int begin_bit, - unsigned int end_bit, - storage_type& storage) + unsigned int begin_offset, + unsigned int end_offset, + unsigned int begin_bit, + unsigned int end_bit, + storage_type& storage) { - const unsigned int num_items = end_offset - begin_offset; + const unsigned int num_items = end_offset - begin_offset; const Key out_of_bounds = key_codec::get_out_of_bounds_key(); Key keys[items_per_thread]; Value values[items_per_thread]; - keys_load_type().load(keys_input + begin_offset, keys, num_items, out_of_bounds, storage.keys_load); + keys_load_type().load(keys_input + begin_offset, + keys, + num_items, + out_of_bounds, + storage.keys_load); - if ROCPRIM_IF_CONSTEXPR(with_values) + if constexpr(with_values) { ::rocprim::wave_barrier(); - values_load_type().load(values_input + begin_offset, values, num_items, storage.values_load); + values_load_type().load(values_input + begin_offset, + values, + num_items, + storage.values_load); } ::rocprim::wave_barrier(); @@ -650,10 +674,13 @@ class segmented_warp_sort_helper< ::rocprim::wave_barrier(); keys_store_type().store(keys_output + begin_offset, keys, num_items, storage.keys_store); - if ROCPRIM_IF_CONSTEXPR(with_values) + if constexpr(with_values) { ::rocprim::wave_barrier(); - values_store_type().store(values_output + begin_offset, values, num_items, storage.values_store); + values_store_type().store(values_output + begin_offset, + values, + num_items, + storage.values_store); } } @@ -677,21 +704,27 @@ class segmented_warp_sort_helper< { if(to_output) { - sort( - keys_input, keys_output, values_input, values_output, - begin_offset, end_offset, - begin_bit, end_bit, - storage - ); + sort(keys_input, + keys_output, + values_input, + values_output, + begin_offset, + end_offset, + begin_bit, + end_bit, + storage); } else { - sort( - keys_input, keys_tmp, values_input, values_tmp, - begin_offset, end_offset, - begin_bit, end_bit, - storage - ); + sort(keys_input, + keys_tmp, + values_input, + values_tmp, + begin_offset, + end_offset, + begin_bit, + end_bit, + storage); } } }; @@ -721,29 +754,29 @@ void segmented_sort(KeysInputIterator keys_input, { static constexpr segmented_radix_sort_config_params params = device_params(); - static constexpr unsigned int long_radix_bits = params.long_radix_bits; + static constexpr unsigned int radix_bits = params.radix_bits; static constexpr unsigned int block_size = params.kernel_config.block_size; static constexpr unsigned int items_per_thread = params.kernel_config.items_per_thread; static constexpr unsigned int items_per_block = block_size * items_per_thread; static constexpr bool warp_sort_enabled = params.enable_unpartitioned_warp_sort; - using key_type = typename std::iterator_traits::value_type; + using key_type = typename std::iterator_traits::value_type; using value_type = typename std::iterator_traits::value_type; - using single_block_helper_type = segmented_radix_sort_single_block_helper< - key_type, value_type, - block_size, items_per_thread, - Descending - >; + using single_block_helper_type = segmented_radix_sort_single_block_helper; using long_radix_helper_type = segmented_radix_sort_helper; - using warp_sort_helper_type = segmented_warp_sort_helper< + using warp_sort_helper_type = segmented_warp_sort_helper< select_warp_sort_helper_config_t(); const unsigned int begin_offset = begin_offsets[segment_id]; - const unsigned int end_offset = end_offsets[segment_id]; + const unsigned int end_offset = end_offsets[segment_id]; // Empty segment if(end_offset <= begin_offset) @@ -776,15 +809,21 @@ void segmented_sort(KeysInputIterator keys_input, if(end_offset - begin_offset > items_per_block) { // Large segment - for(unsigned int bit = begin_bit; bit < end_bit; bit += long_radix_bits) + for(unsigned int bit = begin_bit; bit < end_bit; bit += radix_bits) { - long_radix_helper_type().sort( - keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, - to_output, - begin_offset, end_offset, - bit, begin_bit, end_bit, - storage.long_radix_helper - ); + long_radix_helper_type().sort(keys_input, + keys_tmp, + keys_output, + values_input, + values_tmp, + values_output, + to_output, + begin_offset, + end_offset, + bit, + begin_bit, + end_bit, + storage.long_radix_helper); to_output = !to_output; } @@ -850,26 +889,26 @@ void segmented_sort_large(KeysInputIterator keys_input, { static constexpr segmented_radix_sort_config_params params = device_params(); - static constexpr unsigned int long_radix_bits = params.long_radix_bits; + static constexpr unsigned int radix_bits = params.radix_bits; static constexpr unsigned int block_size = params.kernel_config.block_size; static constexpr unsigned int items_per_thread = params.kernel_config.items_per_thread; static constexpr unsigned int items_per_block = block_size * items_per_thread; - using key_type = typename std::iterator_traits::value_type; + using key_type = typename std::iterator_traits::value_type; using value_type = typename std::iterator_traits::value_type; - using single_block_helper_type = segmented_radix_sort_single_block_helper< - key_type, value_type, - block_size, items_per_thread, - Descending - >; + using single_block_helper_type = segmented_radix_sort_single_block_helper; using long_radix_helper_type = segmented_radix_sort_helper; ROCPRIM_SHARED_MEMORY union @@ -878,10 +917,10 @@ void segmented_sort_large(KeysInputIterator keys_input, typename long_radix_helper_type::storage_type long_radix_helper; } storage; - const unsigned int block_id = ::rocprim::detail::block_id<0>(); - const unsigned int segment_id = segment_indices[block_id]; + const unsigned int block_id = ::rocprim::detail::block_id<0>(); + const unsigned int segment_id = segment_indices[block_id]; const unsigned int begin_offset = begin_offsets[segment_id]; - const unsigned int end_offset = end_offsets[segment_id]; + const unsigned int end_offset = end_offsets[segment_id]; if(end_offset <= begin_offset) { @@ -890,15 +929,21 @@ void segmented_sort_large(KeysInputIterator keys_input, if(end_offset - begin_offset > items_per_block) { - for(unsigned int bit = begin_bit; bit < end_bit; bit += long_radix_bits) + for(unsigned int bit = begin_bit; bit < end_bit; bit += radix_bits) { - long_radix_helper_type().sort( - keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, - to_output, - begin_offset, end_offset, - bit, begin_bit, end_bit, - storage.long_radix_helper - ); + long_radix_helper_type().sort(keys_input, + keys_tmp, + keys_output, + values_input, + values_tmp, + values_output, + to_output, + begin_offset, + end_offset, + bit, + begin_bit, + end_bit, + storage.long_radix_helper); to_output = !to_output; } @@ -1032,7 +1077,7 @@ ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void segmented_sort_medium( "logical_warp_size must be a divisor of block_size"); static constexpr unsigned int warps_per_block = block_size / logical_warp_size; - using key_type = typename std::iterator_traits::value_type; + using key_type = typename std::iterator_traits::value_type; using value_type = typename std::iterator_traits::value_type; using warp_sort_helper_type = segmented_warp_sort_helper< @@ -1047,28 +1092,34 @@ ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void segmented_sort_medium( ROCPRIM_SHARED_MEMORY typename warp_sort_helper_type::storage_type storage; - const unsigned int block_id = ::rocprim::detail::block_id<0>(); + const unsigned int block_id = ::rocprim::detail::block_id<0>(); const unsigned int logical_warp_id = ::rocprim::detail::logical_warp_id(); - const unsigned int segment_index = block_id * warps_per_block + logical_warp_id; + const unsigned int segment_index = block_id * warps_per_block + logical_warp_id; if(segment_index >= num_segments) { return; } - const unsigned int segment_id = segment_indices[segment_index]; + const unsigned int segment_id = segment_indices[segment_index]; const unsigned int begin_offset = begin_offsets[segment_id]; - const unsigned int end_offset = end_offsets[segment_id]; + const unsigned int end_offset = end_offsets[segment_id]; if(end_offset <= begin_offset) { return; } - warp_sort_helper_type().sort( - keys_input, keys_tmp, keys_output, - values_input, values_tmp, values_output, - to_output, begin_offset, end_offset, - begin_bit, end_bit, storage - ); + warp_sort_helper_type().sort(keys_input, + keys_tmp, + keys_output, + values_input, + values_tmp, + values_output, + to_output, + begin_offset, + end_offset, + begin_bit, + end_bit, + storage); } } // end namespace detail diff --git a/rocprim/include/rocprim/device/detail/device_transform.hpp b/rocprim/include/rocprim/device/detail/device_transform.hpp index 1f966e264..bc514b3c7 100644 --- a/rocprim/include/rocprim/device/detail/device_transform.hpp +++ b/rocprim/include/rocprim/device/detail/device_transform.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -21,14 +21,14 @@ #ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_TRANSFORM_HPP_ #define ROCPRIM_DEVICE_DETAIL_DEVICE_TRANSFORM_HPP_ -#include #include +#include #include "../../config.hpp" #include "../../detail/various.hpp" -#include "../../intrinsics.hpp" #include "../../functional.hpp" +#include "../../intrinsics.hpp" #include "../../types.hpp" #include "../../block/block_load.hpp" @@ -39,74 +39,162 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { +template +struct unpack_nary_op +{ + using result_type = typename ::rocprim::invoke_result::type; + + ROCPRIM_HOST_DEVICE inline unpack_nary_op() = default; + + ROCPRIM_HOST_DEVICE inline unpack_nary_op(Function op) : op_(op) {} + + ROCPRIM_HOST_DEVICE inline ~unpack_nary_op() = default; + + ROCPRIM_HOST_DEVICE + inline result_type + operator()(const ::rocprim::tuple& t) const + { + return apply_impl(t, std::index_sequence_for{}); + } + +private: + Function op_; + + template + ROCPRIM_HOST_DEVICE + inline result_type apply_impl(const ::rocprim::tuple& t, + std::index_sequence) const + { + return op_(::rocprim::get(t)...); + } +}; + // Wrapper for unpacking tuple to be used with BinaryFunction. // See transform function which accepts two input iterators. template -struct unpack_binary_op +using unpack_binary_op = unpack_nary_op; + +template +using dynamic_size_type = std::conditional_t< + (sizeof(T) * ItemsPerThread <= 1), + uint8_t, + std::conditional_t< + (sizeof(T) * ItemsPerThread <= 2), + uint16_t, + std::conditional_t< + (sizeof(T) * ItemsPerThread <= 4), + uint32_t, + std::conditional_t<(sizeof(T) * ItemsPerThread <= 8), uint64_t, uint128_t>>>>; + +template +ROCPRIM_DEVICE ROCPRIM_INLINE +auto transform_kernel_impl(InputIterator input, + const size_t input_size, + OutputIterator output, + UnaryFunction transform_op) -> + typename std::enable_if::type { - using result_type = typename ::rocprim::invoke_result::type; + using input_type = typename std::iterator_traits::value_type; + using output_type = typename std::iterator_traits::value_type; + using result_type = + typename std::conditional::value, ResultType, output_type>::type; - ROCPRIM_HOST_DEVICE inline - unpack_binary_op() = default; + constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; - ROCPRIM_HOST_DEVICE inline - unpack_binary_op(BinaryFunction binary_op) : binary_op_(binary_op) - { - } + const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); + const unsigned int flat_block_id = ::rocprim::detail::block_id<0>(); + const unsigned int block_offset = flat_block_id * items_per_block; + const unsigned int number_of_blocks = ::rocprim::detail::grid_size<0>(); + const unsigned int valid_in_last_block = input_size - block_offset; - ROCPRIM_HOST_DEVICE inline - ~unpack_binary_op() = default; + input_type input_values[ItemsPerThread]; + result_type output_values[ItemsPerThread]; - ROCPRIM_HOST_DEVICE inline - result_type operator()(const ::rocprim::tuple& t) + if(flat_block_id == (number_of_blocks - 1)) // last block { - return binary_op_(::rocprim::get<0>(t), ::rocprim::get<1>(t)); + block_load_direct_striped(flat_id, + input + block_offset, + input_values, + valid_in_last_block); + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + if(BlockSize * i + flat_id < valid_in_last_block) + { + output_values[i] = transform_op(input_values[i]); + } + } + + block_store_direct_striped(flat_id, + output + block_offset, + output_values, + valid_in_last_block); } + else + { + using vec_input_type = dynamic_size_type; + block_load_direct_blocked_cast(flat_id, + input + block_offset, + input_values); -private: - BinaryFunction binary_op_; -}; + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + output_values[i] = transform_op(input_values[i]); + } + + using vec_output_type = dynamic_size_type; + block_store_direct_blocked_cast(flat_id, + output + block_offset, + output_values); + } +} -template< - unsigned int BlockSize, - unsigned int ItemsPerThread, - class ResultType, - class InputIterator, - class OutputIterator, - class UnaryFunction -> +template ROCPRIM_DEVICE ROCPRIM_INLINE -void transform_kernel_impl(InputIterator input, - const size_t input_size, +auto transform_kernel_impl(InputIterator input, + const size_t input_size, OutputIterator output, - UnaryFunction transform_op) + UnaryFunction transform_op) -> + typename std::enable_if::type { - using input_type = typename std::iterator_traits::value_type; + using input_type = typename std::iterator_traits::value_type; using output_type = typename std::iterator_traits::value_type; using result_type = - typename std::conditional< - std::is_void::value, ResultType, output_type - >::type; + typename std::conditional::value, ResultType, output_type>::type; constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; - const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); - const unsigned int flat_block_id = ::rocprim::detail::block_id<0>(); - const unsigned int block_offset = flat_block_id * items_per_block; - const unsigned int number_of_blocks = ::rocprim::detail::grid_size<0>(); + const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); + const unsigned int flat_block_id = ::rocprim::detail::block_id<0>(); + const unsigned int block_offset = flat_block_id * items_per_block; + const unsigned int number_of_blocks = ::rocprim::detail::grid_size<0>(); const unsigned int valid_in_last_block = input_size - block_offset; - input_type input_values[ItemsPerThread]; + input_type input_values[ItemsPerThread]; result_type output_values[ItemsPerThread]; if(flat_block_id == (number_of_blocks - 1)) // last block { - block_load_direct_striped( - flat_id, - input + block_offset, - input_values, - valid_in_last_block - ); + block_load_direct_striped(flat_id, + input + block_offset, + input_values, + valid_in_last_block); ROCPRIM_UNROLL for(unsigned int i = 0; i < ItemsPerThread; i++) @@ -117,20 +205,14 @@ void transform_kernel_impl(InputIterator input, } } - block_store_direct_striped( - flat_id, - output + block_offset, - output_values, - valid_in_last_block - ); + block_store_direct_striped(flat_id, + output + block_offset, + output_values, + valid_in_last_block); } else { - block_load_direct_striped( - flat_id, - input + block_offset, - input_values - ); + block_load_direct_striped(flat_id, input + block_offset, input_values); ROCPRIM_UNROLL for(unsigned int i = 0; i < ItemsPerThread; i++) @@ -138,15 +220,11 @@ void transform_kernel_impl(InputIterator input, output_values[i] = transform_op(input_values[i]); } - block_store_direct_striped( - flat_id, - output + block_offset, - output_values - ); + block_store_direct_striped(flat_id, output + block_offset, output_values); } } -} // end of detail namespace +} // namespace detail END_ROCPRIM_NAMESPACE diff --git a/rocprim/include/rocprim/device/detail/lookback_scan_state.hpp b/rocprim/include/rocprim/device/detail/lookback_scan_state.hpp index 63b6c21d3..639ac93f7 100644 --- a/rocprim/include/rocprim/device/detail/lookback_scan_state.hpp +++ b/rocprim/include/rocprim/device/detail/lookback_scan_state.hpp @@ -46,8 +46,7 @@ // Global coherence of prefixes_*_values is ensured by atomic_load/atomic_store that bypass // cache. #ifndef ROCPRIM_DETAIL_LOOKBACK_SCAN_STATE_WITHOUT_SLOW_FENCES - #if defined(__HIP_DEVICE_COMPILE__) \ - && (defined(__gfx942__) || defined(__gfx950__)) + #if defined(__HIP_DEVICE_COMPILE__) && (defined(__gfx942__) || defined(__gfx950__)) #define ROCPRIM_DETAIL_LOOKBACK_SCAN_STATE_WITHOUT_SLOW_FENCES 1 #else #define ROCPRIM_DETAIL_LOOKBACK_SCAN_STATE_WITHOUT_SLOW_FENCES 0 @@ -108,26 +107,31 @@ enum class lookback_scan_determinism default_determinism = nondeterministic, }; -// lookback_scan_state object keeps track of prefixes status for -// a look-back prefix scan. Initially every prefix can be either -// invalid (padding values) or empty. One thread in a block should -// later set it to partial, and later to complete. +constexpr const int MAX_PAYLOAD_SIZE = ROCPRIM_MAX_ATOMIC_SIZE - 1; + +/// \brief Optimized implementation of lookback scan, which is a parallel inclusive scan primitive for device level. +/// +/// This object keeps track of prefixes status for a look-back prefix scan. Initially every prefix can be +/// either invalid (padding values) or empty. One thread in a block should later set it to partial, and later to complete. +/// +/// \tparam T The accumulator type of the scan operation. +/// \tparam UseSleep [optional] If true, the execution of a wavefront is paused for a short duration, allowing other threads or processes to execute during idle periods. +/// \tparam IsSmall [optional] Dependent on the size of `T`. If it's smaller than 16 bytes, it's set to true. template struct lookback_scan_state; /// Reduce lanes `0-valid_items` and return the result in lane 0. template -ROCPRIM_DEVICE ROCPRIM_INLINE T lookback_reduce_forward_init(F scan_op, - T block_prefix, - unsigned int valid_items) +ROCPRIM_DEVICE ROCPRIM_INLINE +T lookback_reduce_forward_init(F scan_op, T block_prefix, unsigned int valid_items) { T prefix = block_prefix; for(unsigned int i = 0; i < valid_items; ++i) { -#ifdef ROCPRIM_DETAIL_HAS_DPP_WF_ROTATE +#ifdef ROCPRIM_DETAIL_HAS_DPP_WF prefix = warp_move_dpp(prefix); #else - prefix = warp_shuffle_down(prefix, 1); + prefix = warp_shuffle_down(prefix, 1, ::rocprim::arch::wavefront::size()); #endif prefix = scan_op(prefix, block_prefix); } @@ -137,11 +141,11 @@ ROCPRIM_DEVICE ROCPRIM_INLINE T lookback_reduce_forward_init(F scan_o /// Reduce all lanes with the `prefix`, which is taken from lane 0, /// and return the result in lane 0. template -ROCPRIM_DEVICE ROCPRIM_INLINE T lookback_reduce_forward(F scan_op, T prefix, T block_prefix) +ROCPRIM_DEVICE ROCPRIM_INLINE +T lookback_reduce_forward(F scan_op, T prefix, T block_prefix) { #ifdef ROCPRIM_DETAIL_HAS_DPP_WF - ROCPRIM_UNROLL - for(unsigned int i = 0; i < ::rocprim::arch::wavefront::min_size(); ++i) + for(unsigned int i = 0; i < ::rocprim::arch::wavefront::size(); ++i) { prefix = warp_move_dpp(prefix); prefix = scan_op(prefix, block_prefix); @@ -150,12 +154,12 @@ ROCPRIM_DEVICE ROCPRIM_INLINE T lookback_reduce_forward(F scan_op, T prefix, T b // If we can't rotate or shift the entire wavefront in one instruction, // iterate over rows of 16 lanes and use warp_readlane to communicate across rows. constexpr const int row_size = 16; - ROCPRIM_UNROLL - for(int j = ::rocprim::arch::wavefront::min_size(); j > 0; j -= row_size) + + for(int j = ::rocprim::arch::wavefront::size(); j > 0; j -= row_size) { prefix = warp_readlane( prefix, - j /* automatically taken modulo ::rocprim::arch::wavefront::min_size(), first read is lane 0 */); + j /* automatically taken modulo ::rocprim::arch::wavefront::size(), first read is lane 0 */); prefix = scan_op(prefix, block_prefix); ROCPRIM_UNROLL @@ -167,10 +171,9 @@ ROCPRIM_DEVICE ROCPRIM_INLINE T lookback_reduce_forward(F scan_op, T prefix, T b } #else // If no DPP available at all, fall back to shuffles. - ROCPRIM_UNROLL - for(unsigned int i = 0; i < ::rocprim::arch::wavefront::min_size(); ++i) + for(unsigned int i = 0; i < ::rocprim::arch::wavefront::size(); ++i) { - prefix = warp_shuffle(prefix, lane_id() + 1); + prefix = warp_shuffle(prefix, lane_id() + 1, ::rocprim::arch::wavefront::size()); prefix = scan_op(prefix, block_prefix); } #endif @@ -189,7 +192,7 @@ struct lookback_scan_state // Helper struct struct prefix_type { - T value; + T value; lookback_scan_prefix_flag flag; }; @@ -201,7 +204,15 @@ struct lookback_scan_state static constexpr bool use_sleep = UseSleep; - // temp_storage must point to allocation of get_storage_size(number_of_blocks) bytes + /// \brief Initializes the lookback_scan_state with the given temporary storage and the given grid size. + /// + /// \param [in,out] state the lookback_scan_state object to be initialized. + /// \param [in] temp_storage the temporary storage necessary for the calculation. Its size can be queried with the get_storage_size function. + /// \param [in] number_of_blocks the grid size for the kernel operation. + /// \param [in] stream the stream which will run the kernel. + /// + /// \returns \p hipSuccess (\p 0) after successful scan; otherwise a HIP runtime error of + /// type \p hipError_t. ROCPRIM_HOST_DEVICE static inline hipError_t create(lookback_scan_state& state, void* temp_storage, @@ -213,6 +224,17 @@ struct lookback_scan_state return hipSuccess; } + /// \brief This function queries the size of the temporary storage for the lookback scan algorithm. + /// + /// \par Overview + /// The lookback_scan needs a certain amount of temporary storage for the calculation. This function calculates the necessary size of the storage. + /// + /// \param [in] number_of_blocks the grid size for the kernel operation. + /// \param [in] stream the stream which will run the kernel. + /// \param [out] storage_size this parameter will contain the storage size in bytes. + /// + /// \returns \p hipSuccess (\p 0) after successful scan; otherwise a HIP runtime error of + /// type \p hipError_t. ROCPRIM_HOST_DEVICE static inline hipError_t get_storage_size(const unsigned int number_of_blocks, const hipStream_t stream, @@ -226,6 +248,17 @@ struct lookback_scan_state return error; } + /// \brief This function queries the layout of the temporary storage for the lookback scan algorithm. + /// + /// \par Overview + /// The lookback_scan needs a certain amount of temporary storage for the calculation. This function queries the layout of the storage. + /// + /// \param [in] number_of_blocks the grid size for the kernel operation. + /// \param [in] stream the stream which will run the kernel. + /// \param [out] layout this parameter will contain the storage layout. + /// + /// \returns \p hipSuccess (\p 0) after successful scan; otherwise a HIP runtime error of + /// type \p hipError_t. ROCPRIM_HOST_DEVICE static inline hipError_t get_temp_storage_layout(const unsigned int number_of_blocks, const hipStream_t stream, @@ -237,10 +270,14 @@ struct lookback_scan_state return error; } - ROCPRIM_DEVICE ROCPRIM_INLINE void initialize_prefix(const unsigned int block_id, - const unsigned int number_of_blocks) + /// \brief This device function initializes the prefixes of the lookback_scan_state instance. + /// + /// \param [in] block_id the prefixes are initialized per block. + /// \param [in] number_of_blocks grid size. + ROCPRIM_DEVICE ROCPRIM_INLINE + void initialize_prefix(const unsigned int block_id, const unsigned int number_of_blocks) { - constexpr unsigned int padding = ::rocprim::arch::wavefront::min_size(); + const unsigned int padding = ::rocprim::arch::wavefront::size(); if(block_id < number_of_blocks) { @@ -260,21 +297,35 @@ struct lookback_scan_state } } - ROCPRIM_DEVICE ROCPRIM_INLINE void set_partial(const unsigned int block_id, const T value) + /// \brief This device function sets the given prefix to the given value and to partial flag. + /// + /// \param [in] block_id the index of the prefix to be updated. + /// \param [in] value the value to update the prefix to. + ROCPRIM_DEVICE ROCPRIM_INLINE + void set_partial(const unsigned int block_id, const T value) { this->set(block_id, lookback_scan_prefix_flag::partial, value); } - ROCPRIM_DEVICE ROCPRIM_INLINE void set_complete(const unsigned int block_id, const T value) + /// \brief This device function sets the given prefix to the given value and to complete flag. + /// + /// \param [in] block_id the index of the prefix to be updated. + /// \param [in] value the value to update the prefix to. + ROCPRIM_DEVICE ROCPRIM_INLINE + void set_complete(const unsigned int block_id, const T value) { this->set(block_id, lookback_scan_prefix_flag::complete, value); } - // block_id must be > 0 + /// \brief This device function queries the value and the flag of the given prefix. + /// + /// \param [in] block_id the index of the prefix to be queried. + /// \param [out] flag the flag of the prefix. + /// \param [out] value the value of the prefix. ROCPRIM_DEVICE ROCPRIM_INLINE void get(const unsigned int block_id, lookback_scan_prefix_flag& flag, T& value) { - constexpr unsigned int padding = ::rocprim::arch::wavefront::min_size(); + const unsigned int padding = ::rocprim::arch::wavefront::size(); prefix_type prefix; @@ -285,7 +336,7 @@ struct lookback_scan_state memcpy(&prefix, &p, sizeof(prefix_type)); while(prefix.flag == lookback_scan_prefix_flag::empty) { - if ROCPRIM_IF_CONSTEXPR(UseSleep) + if constexpr(UseSleep) { for(unsigned int j = 0; j < times_through; j++) __builtin_amdgcn_s_sleep(1); @@ -302,11 +353,15 @@ struct lookback_scan_state value = prefix.value; } - /// \brief Gets the prefix value for a block. Should only be called after all - /// blocks/prefixes are completed. - ROCPRIM_DEVICE ROCPRIM_INLINE T get_complete_value(const unsigned int block_id) + /// \brief This device function queries the value of the given prefix. It should only be called after all the blocks/prefixes are complete. + /// + /// \param [in] block_id the index of the prefix to be queried. + /// + /// \returns the value of the prefix specified by the block_id. + ROCPRIM_DEVICE ROCPRIM_INLINE + T get_complete_value(const unsigned int block_id) { - constexpr unsigned int padding = ::rocprim::arch::wavefront::min_size(); + const unsigned int padding = ::rocprim::arch::wavefront::size(); auto p = prefixes[padding + block_id]; prefix_type prefix{}; @@ -314,14 +369,23 @@ struct lookback_scan_state return prefix.value; } + /// \brief This device function calculates the prefix for the next block, based on this block. + /// + /// \tparam F [optional] The type of the scan_op parameter. + /// + /// \param [in] scan_op the scan operation used. + /// \param [in] block_id the index of the prefix to be processed. + /// + /// \returns the value of the prefix specified by the block_id. template - ROCPRIM_DEVICE ROCPRIM_INLINE T get_prefix_forward(F scan_op, unsigned int block_id_) + ROCPRIM_DEVICE ROCPRIM_INLINE + T get_prefix_forward(F scan_op, unsigned int block_id_) { unsigned int lookback_block_id = block_id_ - lane_id() - 1; // There is one lookback scan per block, though a lookback scan is done by a single warp. // Because every lane of the warp checks a different lookback scan state value, - // we need space for at least ceil(CUs / arch::wavefront::min_size()) items in the cache, + // we need space for at least ceil(CUs / arch::wavefront::size()) items in the cache, // assuming that only one block is active per CU (assumes low occupancy). // For MI300, with 304 CUs, we have 304 / 64 = 5 items for the lookback cache. // Note that one item is kept in the `block_prefix` register, so we only need to @@ -332,7 +396,7 @@ struct lookback_scan_state int cache_offset = 0; lookback_scan_prefix_flag flag; - T block_prefix; + T block_prefix; this->get(lookback_block_id, flag, block_prefix); while(warp_all(flag != lookback_scan_prefix_flag::complete @@ -340,7 +404,7 @@ struct lookback_scan_state && cache_offset < max_lookback_per_thread) { cache[cache_offset++] = block_prefix; - lookback_block_id -= arch::wavefront::min_size(); + lookback_block_id -= arch::wavefront::size(); this->get(lookback_block_id, flag, block_prefix); } @@ -357,7 +421,7 @@ struct lookback_scan_state // All invalid, so we have to move one block back to // get back to known civilization. // Don't forget to pop one item off the cache too. - lookback_block_id += arch::wavefront::min_size(); + lookback_block_id += arch::wavefront::size(); --cache_offset; } @@ -393,7 +457,7 @@ struct lookback_scan_state ROCPRIM_DEVICE ROCPRIM_INLINE void set(const unsigned int block_id, const lookback_scan_prefix_flag flag, const T value) { - constexpr unsigned int padding = ::rocprim::arch::wavefront::min_size(); + const unsigned int padding = ::rocprim::arch::wavefront::size(); prefix_type prefix = {value, flag}; prefix_underlying_type p; @@ -412,11 +476,19 @@ struct lookback_scan_state public: using flag_underlying_type = std::underlying_type_t; - using value_type = T; + using value_type = T; static constexpr bool use_sleep = UseSleep; - // temp_storage must point to allocation of get_storage_size(number_of_blocks) bytes + /// \brief Initializes the lookback_scan_state with the given temporary storage and the given grid size. + /// + /// \param [in,out] state the lookback_scan_state object to be initialized. + /// \param [in] temp_storage the temporary storage necessary for the calculation. Its size can be queried with the get_storage_size function. + /// \param [in] number_of_blocks the grid size for the kernel operation. + /// \param [in] stream the stream which will run the kernel. + /// + /// \returns \p hipSuccess (\p 0) after successful scan; otherwise a HIP runtime error of + /// type \p hipError_t. ROCPRIM_HOST_DEVICE static inline hipError_t create(lookback_scan_state& state, void* temp_storage, @@ -441,6 +513,17 @@ struct lookback_scan_state return error; } + /// \brief This function queries the size of the temporary storage for the lookback scan algorithm. + /// + /// \par Overview + /// The lookback_scan needs a certain amount of temporary storage for the calculation. This function calculates the necessary size of the storage. + /// + /// \param [in] number_of_blocks the grid size for the kernel operation. + /// \param [in] stream the stream which will run the kernel. + /// \param [out] storage_size this parameter will contain the storage size in bytes. + /// + /// \returns \p hipSuccess (\p 0) after successful scan; otherwise a HIP runtime error of + /// type \p hipError_t. ROCPRIM_HOST_DEVICE static inline hipError_t get_storage_size(const unsigned int number_of_blocks, const hipStream_t stream, @@ -456,6 +539,17 @@ struct lookback_scan_state return error; } + /// \brief This function queries the layout of the temporary storage for the lookback scan algorithm. + /// + /// \par Overview + /// The lookback_scan needs a certain amount of temporary storage for the calculation. This function queries the layout of the storage. + /// + /// \param [in] number_of_blocks the grid size for the kernel operation. + /// \param [in] stream the stream which will run the kernel. + /// \param [out] layout this parameter will contain the storage layout. + /// + /// \returns \p hipSuccess (\p 0) after successful scan; otherwise a HIP runtime error of + /// type \p hipError_t. ROCPRIM_HOST_DEVICE static inline hipError_t get_temp_storage_layout(const unsigned int number_of_blocks, const hipStream_t stream, @@ -469,10 +563,14 @@ struct lookback_scan_state return error; } - ROCPRIM_DEVICE ROCPRIM_INLINE void initialize_prefix(const unsigned int block_id, - const unsigned int number_of_blocks) + /// \brief This device function initializes the prefixes of the lookback_scan_state instance. + /// + /// \param [in] block_id the prefixes are initialized per block. + /// \param [in] number_of_blocks grid size. + ROCPRIM_DEVICE ROCPRIM_INLINE + void initialize_prefix(const unsigned int block_id, const unsigned int number_of_blocks) { - constexpr unsigned int padding = ::rocprim::arch::wavefront::min_size(); + const unsigned int padding = ::rocprim::arch::wavefront::size(); if(block_id < number_of_blocks) { prefixes_flags[padding + block_id] @@ -485,21 +583,35 @@ struct lookback_scan_state } } - ROCPRIM_DEVICE ROCPRIM_INLINE void set_partial(const unsigned int block_id, const T value) + /// \brief Set the given prefix to the given value and to partial flag. + /// + /// \param [in] block_id the index of the prefix to be updated. + /// \param [in] value the value to update the prefix to. + ROCPRIM_DEVICE ROCPRIM_INLINE + void set_partial(const unsigned int block_id, const T value) { this->set(block_id, lookback_scan_prefix_flag::partial, value); } - ROCPRIM_DEVICE ROCPRIM_INLINE void set_complete(const unsigned int block_id, const T value) + /// \brief This device function sets the given prefix to the given value and to complete flag. + /// + /// \param [in] block_id the index of the prefix to be updated. + /// \param [in] value the value to update the prefix to. + ROCPRIM_DEVICE ROCPRIM_INLINE + void set_complete(const unsigned int block_id, const T value) { this->set(block_id, lookback_scan_prefix_flag::complete, value); } - // block_id must be > 0 + /// \brief This device function queries the value and the flag of the given prefix. + /// + /// \param [in] block_id the index of the prefix to be queried. + /// \param [out] flag the flag of the prefix. + /// \param [out] value the value of the prefix. ROCPRIM_DEVICE ROCPRIM_INLINE void get(const unsigned int block_id, lookback_scan_prefix_flag& flag, T& value) { - constexpr unsigned int padding = ::rocprim::arch::wavefront::min_size(); + const unsigned int padding = ::rocprim::arch::wavefront::size(); flag = this->get_flag(block_id); #if ROCPRIM_DETAIL_LOOKBACK_SCAN_STATE_WITHOUT_SLOW_FENCES @@ -520,15 +632,19 @@ struct lookback_scan_state const auto* values = static_cast(flag == lookback_scan_prefix_flag::partial ? prefixes_partial_values : prefixes_complete_values); - value = values[padding + block_id]; + value = values[padding + block_id]; #endif } - /// \brief Gets the prefix value for a block. Should only be called after all - /// blocks/prefixes are completed. - ROCPRIM_DEVICE ROCPRIM_INLINE T get_complete_value(const unsigned int block_id) + /// \brief This device function queries the value of the given prefix. It should only be called after all the blocks/prefixes are complete. + /// + /// \param [in] block_id the index of the prefix to be queried. + /// + /// \returns the value of the prefix specified by the block_id. + ROCPRIM_DEVICE ROCPRIM_INLINE + T get_complete_value(const unsigned int block_id) { - constexpr unsigned int padding = ::rocprim::arch::wavefront::min_size(); + const unsigned int padding = ::rocprim::arch::wavefront::size(); #if ROCPRIM_DETAIL_LOOKBACK_SCAN_STATE_WITHOUT_SLOW_FENCES T value; @@ -546,9 +662,10 @@ struct lookback_scan_state #endif } - ROCPRIM_DEVICE ROCPRIM_INLINE T get_partial_value(const unsigned int block_id) + ROCPRIM_DEVICE ROCPRIM_INLINE + T get_partial_value(const unsigned int block_id) { - constexpr unsigned int padding = ::rocprim::arch::wavefront::min_size(); + const unsigned int padding = ::rocprim::arch::wavefront::size(); #if ROCPRIM_DETAIL_LOOKBACK_SCAN_STATE_WITHOUT_SLOW_FENCES T value; @@ -566,8 +683,17 @@ struct lookback_scan_state #endif } + /// \brief This device function calculates the prefix for the next block, based on this block. + /// + /// \tparam F [optional] The type of the scan_op parameter. + /// + /// \param [in] scan_op the scan operation used. + /// \param [in] block_id the index of the prefix to be processed. + /// + /// \returns the value of the prefix specified by the block_id. template - ROCPRIM_DEVICE ROCPRIM_INLINE T get_prefix_forward(F scan_op, unsigned int block_id_) + ROCPRIM_DEVICE ROCPRIM_INLINE + T get_prefix_forward(F scan_op, unsigned int block_id_) { unsigned int lookback_block_id = block_id_ - lane_id() - 1; @@ -579,7 +705,7 @@ struct lookback_scan_state && flag != lookback_scan_prefix_flag::invalid)) { ++cache_offset; - lookback_block_id -= arch::wavefront::min_size(); + lookback_block_id -= arch::wavefront::size(); flag = this->get_flag(lookback_block_id); } @@ -596,7 +722,7 @@ struct lookback_scan_state // All invalid, so we have to move one block back to // get back to known civilization. // Don't forget to pop one item off the cache too. - lookback_block_id += arch::wavefront::min_size(); + lookback_block_id += arch::wavefront::size(); --cache_offset; } @@ -624,7 +750,7 @@ struct lookback_scan_state // These are all guaranteed to be PARTIAL while(cache_offset > 0) { - lookback_block_id += arch::wavefront::min_size(); + lookback_block_id += arch::wavefront::size(); --cache_offset; block_prefix = this->get_partial_value(lookback_block_id); prefix = lookback_reduce_forward(scan_op, prefix, block_prefix); @@ -637,7 +763,7 @@ struct lookback_scan_state ROCPRIM_DEVICE ROCPRIM_INLINE lookback_scan_prefix_flag get_flag(const unsigned int block_id) { - constexpr unsigned int padding = ::rocprim::arch::wavefront::min_size(); + const unsigned int padding = ::rocprim::arch::wavefront::size(); const unsigned int SLEEP_MAX = 32; unsigned int times_through = 1; @@ -663,7 +789,7 @@ struct lookback_scan_state ROCPRIM_DEVICE ROCPRIM_INLINE void set(const unsigned int block_id, const lookback_scan_prefix_flag flag, const T value) { - constexpr unsigned int padding = ::rocprim::arch::wavefront::min_size(); + const unsigned int padding = ::rocprim::arch::wavefront::size(); #if ROCPRIM_DETAIL_LOOKBACK_SCAN_STATE_WITHOUT_SLOW_FENCES auto* values = static_cast( @@ -699,15 +825,18 @@ struct lookback_scan_state // We need to separate arrays for partial and final prefixes, because // value can be overwritten before flag is changed (flag and value are // not stored in single instruction). - void* prefixes_partial_values; - void* prefixes_complete_values; + void* prefixes_partial_values; + void* prefixes_complete_values; flag_underlying_type* prefixes_flags; }; template + lookback_scan_determinism Determinism = lookback_scan_determinism::default_determinism, + ::rocprim::arch::wavefront::target TargetWaveSize + = ::rocprim::arch::wavefront::get_target(), + typename Enabled = void> class lookback_scan_prefix_op { static_assert(std::is_same::value, @@ -715,8 +844,8 @@ class lookback_scan_prefix_op public: ROCPRIM_DEVICE ROCPRIM_INLINE lookback_scan_prefix_op(unsigned int block_id, - BinaryFunction scan_op, - LookbackScanState& scan_state) + BinaryFunction scan_op, + LookbackScanState& scan_state) : block_id_(block_id), scan_op_(scan_op), scan_state_(scan_state) {} @@ -733,7 +862,9 @@ class lookback_scan_prefix_op // from (block_id_ - 2) block etc. using headflag_scan_op_type = reverse_binary_op_wrapper; using warp_reduce_prefix_type - = warp_reduce_crosslane; + = warp_reduce_crosslane(), + false>; T block_prefix; scan_state_.get(block_id, flag, block_prefix); @@ -746,13 +877,14 @@ class lookback_scan_prefix_op headflag_scan_op); } - ROCPRIM_DEVICE ROCPRIM_INLINE T get_prefix() + ROCPRIM_DEVICE ROCPRIM_INLINE + T get_prefix() { - if ROCPRIM_IF_CONSTEXPR(Determinism == lookback_scan_determinism::nondeterministic) + if constexpr(Determinism == lookback_scan_determinism::nondeterministic) { lookback_scan_prefix_flag flag; - T partial_prefix; - unsigned int previous_block_id = block_id_ - ::rocprim::lane_id() - 1; + T partial_prefix; + unsigned int previous_block_id = block_id_ - ::rocprim::lane_id() - 1; // reduce last warp_size() number of prefixes to // get the complete prefix for this block. @@ -762,7 +894,7 @@ class lookback_scan_prefix_op // while we don't load a complete prefix, reduce partial prefixes while(::rocprim::detail::warp_all(flag != lookback_scan_prefix_flag::complete)) { - previous_block_id -= ::rocprim::arch::wavefront::min_size(); + previous_block_id -= ::rocprim::arch::wavefront::size_from_target(); reduce_partial_prefixes(previous_block_id, flag, partial_prefix); prefix = scan_op_(partial_prefix, prefix); } @@ -775,7 +907,8 @@ class lookback_scan_prefix_op } public: - ROCPRIM_DEVICE ROCPRIM_INLINE T operator()(T reduction) + ROCPRIM_DEVICE ROCPRIM_INLINE + T operator()(T reduction) { // Set partial prefix for next block if(::rocprim::lane_id() == 0) @@ -800,6 +933,57 @@ class lookback_scan_prefix_op LookbackScanState& scan_state_; }; +template +class lookback_scan_prefix_op +{ +public: + ROCPRIM_DEVICE ROCPRIM_INLINE lookback_scan_prefix_op(unsigned int block_id, + BinaryFunction scan_op, + LookbackScanState& scan_state) + : wave32_op(block_id, scan_op, scan_state), wave64_op(block_id, scan_op, scan_state) + {} + + ROCPRIM_DEVICE ROCPRIM_INLINE ~lookback_scan_prefix_op() = default; + +private: + using lookback_scan_prefix_op_wave32 + = lookback_scan_prefix_op; + lookback_scan_prefix_op_wave32 wave32_op; + using lookback_scan_prefix_op_wave64 + = lookback_scan_prefix_op; + lookback_scan_prefix_op_wave64 wave64_op; + +public: + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto operator()(Args&&... args) + { + if(::rocprim::arch::wavefront::size() == ROCPRIM_WARP_SIZE_32) + { + return wave32_op(args...); + } + else + { + return wave64_op(args...); + } + } +}; + // This is a HOST only API // It is known that early revisions of MI100 (gfx908) hang in the wait loop of // lookback_scan_state::get() without sleeping (s_sleep). @@ -860,7 +1044,8 @@ class offset_lookback_scan_factory ROCPRIM_DETAIL_SUPPRESS_DEPRECATION_POP template - static ROCPRIM_DEVICE auto create(PrefixOp& prefix_op, storage_type& storage) + static ROCPRIM_DEVICE + auto create(PrefixOp& prefix_op, storage_type& storage) { return [&](T reduction) mutable { @@ -874,12 +1059,14 @@ class offset_lookback_scan_factory }; } - static ROCPRIM_DEVICE T get_reduction(const storage_type& storage) + static ROCPRIM_DEVICE + T get_reduction(const storage_type& storage) { return storage.get().block_reduction; } - static ROCPRIM_DEVICE T get_prefix(const storage_type& storage) + static ROCPRIM_DEVICE + T get_prefix(const storage_type& storage) { return storage.get().prefix; } @@ -896,7 +1083,8 @@ class offset_lookback_scan_prefix_op using base_type = lookback_scan_prefix_op; using factory = detail::offset_lookback_scan_factory; - ROCPRIM_DEVICE ROCPRIM_INLINE base_type& base() + ROCPRIM_DEVICE ROCPRIM_INLINE + base_type& base() { return *this; } @@ -905,29 +1093,33 @@ class offset_lookback_scan_prefix_op using storage_type = typename factory::storage_type; ROCPRIM_DEVICE ROCPRIM_INLINE offset_lookback_scan_prefix_op(unsigned int block_id, - LookbackScanState& state, - storage_type& storage, - BinaryOp binary_op = BinaryOp()) + LookbackScanState& state, + storage_type& storage, + BinaryOp binary_op = BinaryOp()) : base_type(block_id, BinaryOp(std::move(binary_op)), state), storage(storage) {} - ROCPRIM_DEVICE ROCPRIM_INLINE T operator()(T reduction) + ROCPRIM_DEVICE ROCPRIM_INLINE + T operator()(T reduction) { return factory::create(base(), storage)(reduction); } - ROCPRIM_DEVICE ROCPRIM_INLINE T get_reduction() const + ROCPRIM_DEVICE ROCPRIM_INLINE + T get_reduction() const { return factory::get_reduction(storage); } - ROCPRIM_DEVICE ROCPRIM_INLINE T get_prefix() const + ROCPRIM_DEVICE ROCPRIM_INLINE + T get_prefix() const { return factory::get_prefix(storage); } // rocThrust uses this implementation detail of rocPRIM, required for backwards compatibility - ROCPRIM_DEVICE ROCPRIM_INLINE T get_exclusive_prefix() const + ROCPRIM_DEVICE ROCPRIM_INLINE + T get_exclusive_prefix() const { return get_prefix(); } diff --git a/rocprim/include/rocprim/device/device_adjacent_difference.hpp b/rocprim/include/rocprim/device/device_adjacent_difference.hpp index dc3a0696b..c81cd0335 100644 --- a/rocprim/include/rocprim/device/device_adjacent_difference.hpp +++ b/rocprim/include/rocprim/device/device_adjacent_difference.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -28,8 +28,8 @@ #include "config_types.hpp" #include "device_transform.hpp" -#include "../config.hpp" #include "../common.hpp" +#include "../config.hpp" #include "../functional.hpp" #include "../detail/temp_storage.hpp" @@ -81,12 +81,12 @@ void ROCPRIM_KERNEL ROCPRIM_LAUNCH_BOUNDS( starting_block); } -template +template hipError_t adjacent_difference_impl(void* const temporary_storage, std::size_t& storage_size, const InputIt input, @@ -97,7 +97,7 @@ hipError_t adjacent_difference_impl(void* const temporary_storage, const bool debug_synchronous) { using value_type = typename std::iterator_traits::value_type; - using output_type = rocprim::invoke_result_binary_op_t; + using output_type = ::rocprim::accumulator_t; using larger_type = std::conditional_t<(sizeof(value_type) >= sizeof(output_type)), value_type, output_type>; @@ -137,7 +137,7 @@ hipError_t adjacent_difference_impl(void* const temporary_storage, // Copy values before they are overwritten to use as tile predecessors/successors // previous_values is not dereferenced when the operation is not in place - if ROCPRIM_IF_CONSTEXPR(InPlace) + if constexpr(InPlace) { // If doing left adjacent diff then the last item of each block is needed for the // next block, otherwise the first item is needed for the previous block @@ -151,7 +151,7 @@ hipError_t adjacent_difference_impl(void* const temporary_storage, const hipError_t error = ::rocprim::transform(block_starts_iter, previous_values, num_blocks - 1, - rocprim::identity<> {}, + rocprim::identity<>{}, stream, debug_synchronous); if(error != hipSuccess) @@ -206,15 +206,14 @@ hipError_t adjacent_difference_impl(void* const temporary_storage, op, previous_values + starting_block, starting_block); - ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR( - "adjacent_difference_kernel", current_size, start); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("adjacent_difference_kernel", + current_size, + start); } return hipSuccess; } } // namespace detail - - #endif // DOXYGEN_SHOULD_SKIP_THIS /// \addtogroup devicemodule @@ -297,23 +296,29 @@ hipError_t adjacent_difference_impl(void* const temporary_storage, /// // output: [8, 1, 1, 1, 1, 1, 1, 1] /// \endcode /// \endparblock -template > +template> hipError_t adjacent_difference(void* const temporary_storage, std::size_t& storage_size, const InputIt input, const OutputIt output, const std::size_t size, - const BinaryFunction op = BinaryFunction {}, + const BinaryFunction op = BinaryFunction{}, const hipStream_t stream = 0, const bool debug_synchronous = false) { static constexpr bool in_place = false; static constexpr bool right = false; - return detail::adjacent_difference_impl( - temporary_storage, storage_size, input, output, size, op, stream, debug_synchronous); + return detail::adjacent_difference_impl(temporary_storage, + storage_size, + input, + output, + size, + op, + stream, + debug_synchronous); } /// \brief Parallel primitive for applying a binary operation across pairs of consecutive elements @@ -350,21 +355,27 @@ hipError_t adjacent_difference(void* const temporary_storage, /// /// \return `hipSuccess` (0) after successful scan, otherwise the HIP runtime error of /// type `hipError_t` -template > +template> hipError_t adjacent_difference_inplace(void* const temporary_storage, std::size_t& storage_size, const InputIt values, const std::size_t size, - const BinaryFunction op = BinaryFunction {}, + const BinaryFunction op = BinaryFunction{}, const hipStream_t stream = 0, const bool debug_synchronous = false) { static constexpr bool in_place = true; static constexpr bool right = false; - return detail::adjacent_difference_impl( - temporary_storage, storage_size, values, values, size, op, stream, debug_synchronous); + return detail::adjacent_difference_impl(temporary_storage, + storage_size, + values, + values, + size, + op, + stream, + debug_synchronous); } /// \brief Parallel primitive for applying a binary operation across pairs of consecutive elements @@ -502,23 +513,29 @@ hipError_t adjacent_difference_inplace(void* const temporary_storage, /// // output: [1, 1, 1, 1, 1, 1, 1, 8] /// \endcode /// \endparblock -template > +template> hipError_t adjacent_difference_right(void* const temporary_storage, std::size_t& storage_size, const InputIt input, const OutputIt output, const std::size_t size, - const BinaryFunction op = BinaryFunction {}, + const BinaryFunction op = BinaryFunction{}, const hipStream_t stream = 0, const bool debug_synchronous = false) { static constexpr bool in_place = false; static constexpr bool right = true; - return detail::adjacent_difference_impl( - temporary_storage, storage_size, input, output, size, op, stream, debug_synchronous); + return detail::adjacent_difference_impl(temporary_storage, + storage_size, + input, + output, + size, + op, + stream, + debug_synchronous); } /// \brief Parallel primitive for applying a binary operation across pairs of consecutive elements @@ -555,21 +572,27 @@ hipError_t adjacent_difference_right(void* const temporary_storage, /// /// \return `hipSuccess` (0) after successful scan, otherwise the HIP runtime error of /// type `hipError_t` -template > +template> hipError_t adjacent_difference_right_inplace(void* const temporary_storage, std::size_t& storage_size, const InputIt values, const std::size_t size, - const BinaryFunction op = BinaryFunction {}, + const BinaryFunction op = BinaryFunction{}, const hipStream_t stream = 0, const bool debug_synchronous = false) { static constexpr bool in_place = true; static constexpr bool right = true; - return detail::adjacent_difference_impl( - temporary_storage, storage_size, values, values, size, op, stream, debug_synchronous); + return detail::adjacent_difference_impl(temporary_storage, + storage_size, + values, + values, + size, + op, + stream, + debug_synchronous); } /// \brief Parallel primitive for applying a binary operation across pairs of consecutive elements diff --git a/rocprim/include/rocprim/device/device_binary_search.hpp b/rocprim/include/rocprim/device/device_binary_search.hpp index 31c73ab28..e4dd5f221 100644 --- a/rocprim/include/rocprim/device/device_binary_search.hpp +++ b/rocprim/include/rocprim/device/device_binary_search.hpp @@ -70,17 +70,14 @@ hipError_t binary_search(void * temporary_storage, return hipSuccess; } - return transform( - needles, output, + return detail::transform_impl( + needles, + output, needles_size, - [haystack, haystack_size, search_op, compare_op] - ROCPRIM_DEVICE - (const value_type& value) - { - return search_op(haystack, haystack_size, value, compare_op); - }, - stream, debug_synchronous - ); + [haystack, haystack_size, search_op, compare_op](const value_type& value) + { return search_op(haystack, haystack_size, value, compare_op); }, + stream, + debug_synchronous); } template diff --git a/rocprim/include/rocprim/device/device_binary_search_config.hpp b/rocprim/include/rocprim/device/device_binary_search_config.hpp index 7b4a968a0..3580830b8 100644 --- a/rocprim/include/rocprim/device/device_binary_search_config.hpp +++ b/rocprim/include/rocprim/device/device_binary_search_config.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -50,8 +50,8 @@ template struct default_config_for_lower_bound {}; -template -struct wrapped_transform_config, Unused> +template +struct wrapped_transform_config, Unused, IsPointer> { template struct architecture_config @@ -61,8 +61,8 @@ struct wrapped_transform_config, }; }; -template -struct wrapped_transform_config, Unused> +template +struct wrapped_transform_config, Unused, IsPointer> { template struct architecture_config @@ -72,8 +72,8 @@ struct wrapped_transform_config, U }; }; -template -struct wrapped_transform_config, Unused> +template +struct wrapped_transform_config, Unused, IsPointer> { template struct architecture_config @@ -84,21 +84,21 @@ struct wrapped_transform_config, U }; #ifndef DOXYGEN_SHOULD_SKIP_THIS -template +template template constexpr transform_config_params - wrapped_transform_config, - Unused>::architecture_config::params; -template + wrapped_transform_config, Unused, IsPointer>:: + architecture_config::params; +template template constexpr transform_config_params - wrapped_transform_config, - Unused>::architecture_config::params; -template + wrapped_transform_config, Unused, IsPointer>:: + architecture_config::params; +template template constexpr transform_config_params - wrapped_transform_config, - Unused>::architecture_config::params; + wrapped_transform_config, Unused, IsPointer>:: + architecture_config::params; #endif // DOXYGEN_SHOULD_SKIP_THIS } // end namespace detail diff --git a/rocprim/include/rocprim/device/device_merge.hpp b/rocprim/include/rocprim/device/device_merge.hpp index e8fc0d2a2..e6e037469 100644 --- a/rocprim/include/rocprim/device/device_merge.hpp +++ b/rocprim/include/rocprim/device/device_merge.hpp @@ -25,12 +25,12 @@ #include #include -#include "../config.hpp" #include "../common.hpp" +#include "../config.hpp" #include "../detail/temp_storage.hpp" #include "../detail/various.hpp" +#include "../detail/virtual_shared_memory.hpp" -#include "device_merge_config.hpp" #include "detail/device_merge.hpp" BEGIN_ROCPRIM_NAMESPACE @@ -80,20 +80,40 @@ ROCPRIM_KERNEL ROCPRIM_LAUNCH_BOUNDS(device_params().kernel_config.block ValuesOutputIterator values_output, const size_t input1_size, const size_t input2_size, - BinaryFunction compare_function) + BinaryFunction compare_function, + detail::vsmem_t vsmem) +{ + using key_type = typename std::iterator_traits::value_type; + using value_type = typename std::iterator_traits::value_type; + + using merge_kernel_impl_t = merge_kernel_impl_; + + using VSmemHelperT = detail::vsmem_helper_impl; + ROCPRIM_SHARED_MEMORY typename VSmemHelperT::static_temp_storage_t static_temp_storage; + // Get temporary storage + typename merge_kernel_impl_t::storage_type& storage + = VSmemHelperT::get_temp_storage(static_temp_storage, vsmem); + + merge_kernel_impl_t().merge(index, + keys_input1, + keys_input2, + keys_output, + values_input1, + values_input2, + values_output, + input1_size, + input2_size, + compare_function, + storage); +} + +template +inline size_t get_merge_vsmem_size_per_block() { - static constexpr merge_config_params params = device_params(); - merge_kernel_impl( - index, - keys_input1, - keys_input2, - keys_output, - values_input1, - values_input2, - values_output, - input1_size, - input2_size, - compare_function); + using merge_kernel_impl_t = merge_kernel_impl_; + using MergeVSmemHelperT = detail::vsmem_helper_impl; + + return MergeVSmemHelperT::vsmem_per_block; } template< @@ -138,29 +158,40 @@ hipError_t merge_impl(void * temporary_storage, const unsigned int block_size = params.kernel_config.block_size; const unsigned int half_block = block_size / 2; const unsigned int items_per_thread = params.kernel_config.items_per_thread; - const auto items_per_block = block_size * items_per_thread; + const unsigned int items_per_block = block_size * items_per_thread; - const unsigned int partitions + const unsigned int number_of_blocks = ((input1_size + input2_size) + items_per_block - 1) / items_per_block; - unsigned int* index; + size_t virtual_shared_memory_size + = get_merge_vsmem_size_per_block() * number_of_blocks; + + unsigned int* index = nullptr; + void* vsmem = nullptr; const hipError_t partition_result = detail::temp_storage::partition( temporary_storage, storage_size, - detail::temp_storage::ptr_aligned_array(&index, partitions + 1)); + detail::temp_storage::make_linear_partition( + detail::temp_storage::ptr_aligned_array(&index, number_of_blocks + 1), + // vsmem + detail::temp_storage::make_partition(&vsmem, + virtual_shared_memory_size, + cache_line_size))); + if(partition_result != hipSuccess || temporary_storage == nullptr) { return partition_result; } - if( partitions == 0u ) + if(number_of_blocks == 0u) + { return hipSuccess; + } // Start point for time measurements std::chrono::steady_clock::time_point start; - auto number_of_blocks = partitions; if(debug_synchronous) { std::cout << "block_size " << block_size << '\n'; @@ -168,39 +199,32 @@ hipError_t merge_impl(void * temporary_storage, std::cout << "items_per_block " << items_per_block << '\n'; } - const unsigned partition_blocks = ((partitions + 1) + half_block - 1) / half_block; + const unsigned int partition_blocks = ((number_of_blocks + 1) + half_block - 1) / half_block; if(debug_synchronous) start = std::chrono::steady_clock::now(); - hipLaunchKernelGGL(HIP_KERNEL_NAME(detail::partition_kernel), - dim3(partition_blocks), - dim3(half_block), - 0, - stream, - index, - keys_input1, - keys_input2, - input1_size, - input2_size, - items_per_block, - compare_function); + detail::partition_kernel + <<>>(index, + keys_input1, + keys_input2, + input1_size, + input2_size, + items_per_block, + compare_function); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("partition_kernel", input1_size, start); if(debug_synchronous) start = std::chrono::steady_clock::now(); - hipLaunchKernelGGL(HIP_KERNEL_NAME(detail::merge_kernel), - dim3(number_of_blocks), - dim3(block_size), - 0, - stream, - index, - keys_input1, - keys_input2, - keys_output, - values_input1, - values_input2, - values_output, - input1_size, - input2_size, - compare_function); + detail::merge_kernel + <<>>(index, + keys_input1, + keys_input2, + keys_output, + values_input1, + values_input2, + values_output, + input1_size, + input2_size, + compare_function, + detail::vsmem_t{vsmem}); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("merge_kernel", input1_size, start); return hipSuccess; diff --git a/rocprim/include/rocprim/device/device_merge_inplace.hpp b/rocprim/include/rocprim/device/device_merge_inplace.hpp index e82ec1cf6..c1d08a066 100644 --- a/rocprim/include/rocprim/device/device_merge_inplace.hpp +++ b/rocprim/include/rocprim/device/device_merge_inplace.hpp @@ -36,7 +36,6 @@ #include "../intrinsics/thread.hpp" #include "../thread/thread_search.hpp" -#include #include #include @@ -652,7 +651,7 @@ struct merge_inplace_impl /// signature of the function should be equivalent to the following: `bool f(const T &a, const T &b);`. /// The signature does not need to have `const &`, but the function object must not modify /// the objects passed to it. The default value is `BinaryFunction()`. -/// \param [in] stream The HIP stream object. Default is `0` (`hipDefaultStream`). +/// \param [in] stream The HIP stream object. Default is `0` (`hipStreamDefault`). /// \param [in] debug_synchronous If `true`, forces a device synchronization after every kernel /// launch in order to check for errors. Default value is `false`. /// diff --git a/rocprim/include/rocprim/device/device_partial_sort.hpp b/rocprim/include/rocprim/device/device_partial_sort.hpp index 1ce720ab7..1653e0c06 100644 --- a/rocprim/include/rocprim/device/device_partial_sort.hpp +++ b/rocprim/include/rocprim/device/device_partial_sort.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -59,7 +59,8 @@ struct radix_sort_condition_checker static constexpr bool descending = std::is_same>::value; static constexpr bool ascending = std::is_same>::value; - static constexpr bool is_radix_key_fundamental = detail::radix_key_fundamental::value; + static constexpr bool is_radix_key_fundamental + = rocprim::traits::radix_key_codec::radix_key_fundamental::value; static constexpr bool use_radix_sort = (is_radix_key_fundamental || is_custom_decomposer) && (descending || ascending); }; diff --git a/rocprim/include/rocprim/device/device_reduce.hpp b/rocprim/include/rocprim/device/device_reduce.hpp index 65e17563c..a94cb31b7 100644 --- a/rocprim/include/rocprim/device/device_reduce.hpp +++ b/rocprim/include/rocprim/device/device_reduce.hpp @@ -119,8 +119,7 @@ hipError_t reduce_impl(void * temporary_storage, bool debug_synchronous) { using input_type = typename std::iterator_traits::value_type; - using result_type = - typename ::rocprim::invoke_result_binary_op::type; + using result_type = ::rocprim::accumulator_t; using config = wrapped_reduce_config; diff --git a/rocprim/include/rocprim/device/device_scan.hpp b/rocprim/include/rocprim/device/device_scan.hpp index 575675e05..d70a127eb 100644 --- a/rocprim/include/rocprim/device/device_scan.hpp +++ b/rocprim/include/rocprim/device/device_scan.hpp @@ -25,8 +25,8 @@ #include #include -#include "../config.hpp" #include "../common.hpp" +#include "../config.hpp" #include "../detail/temp_storage.hpp" #include "../detail/various.hpp" #include "../functional.hpp" @@ -48,6 +48,7 @@ namespace detail // Single kernel scan (performs scan on one thread block only) template(values, // input - values, // output - initial_value, - storage.scan, - scan_op); + single_scan_block_scan(values, // input + values, // output + initial_value, + storage.scan, + scan_op); ::rocprim::syncthreads(); // sync threads to reuse shared memory // Save values into output array @@ -94,6 +95,7 @@ ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void single_scan_kernel_impl(InputIterator } template().kernel_config.block OutputIterator output, BinaryFunction scan_op) { - single_scan_kernel_impl(input, - size, - static_cast(get_input_value(initial_value)), - output, - scan_op); + single_scan_kernel_impl( + input, + size, + static_cast(get_input_value(initial_value)), + output, + scan_op); } // Single pass (look-back kernels) template -ROCPRIM_KERNEL - ROCPRIM_LAUNCH_BOUNDS(device_params().kernel_config.block_size) void - lookback_scan_kernel(InputIterator input, - OutputIterator output, - const size_t size, - const InitValueType initial_value, - BinaryFunction scan_op, - LookBackScanState lookback_scan_state, - const unsigned int number_of_blocks, - AccType* previous_last_element = nullptr, - AccType* new_last_element = nullptr, - bool override_first_value = false, - bool save_last_value = false) +ROCPRIM_KERNEL ROCPRIM_LAUNCH_BOUNDS(device_params().kernel_config.block_size) void + lookback_scan_kernel(InputIterator input, + OutputIterator output, + const size_t size, + InitValueType initial_value, + BinaryFunction scan_op, + LookBackScanState lookback_scan_state, + const unsigned int number_of_blocks, + AccType* previous_last_element = nullptr, + AccType* new_last_element = nullptr, + bool override_first_value = false, + bool save_last_value = false) { - lookback_scan_kernel_impl( + lookback_scan_kernel_impl( input, output, size, @@ -155,6 +158,7 @@ ROCPRIM_KERNEL template(size_limit - size_limit % items_per_block, items_per_block); - size_t limited_size = std::min(size, aligned_size_limit); + size_t limited_size = std::min(size, aligned_size_limit); const bool use_limited_size = limited_size == aligned_size_limit; - unsigned int number_of_blocks = (limited_size + items_per_block - 1)/items_per_block; + unsigned int number_of_blocks = (limited_size + items_per_block - 1) / items_per_block; // Pointer to array with block_prefixes void* scan_state_storage; @@ -226,7 +230,7 @@ inline auto scan_impl(void* temporary_storage, // Start point for time measurements std::chrono::steady_clock::time_point start; - if( number_of_blocks == 0u ) + if(number_of_blocks == 0u) return hipSuccess; if(number_of_blocks > 1 || use_limited_size) @@ -266,14 +270,15 @@ inline auto scan_impl(void* temporary_storage, } }; - if(debug_synchronous) start = std::chrono::steady_clock::now(); + if(debug_synchronous) + start = std::chrono::steady_clock::now(); - size_t number_of_launch = (size + limited_size - 1)/limited_size; - for (size_t i = 0, offset = 0; i < number_of_launch; i++, offset+=limited_size ) + size_t number_of_launch = (size + limited_size - 1) / limited_size; + for(size_t i = 0, offset = 0; i < number_of_launch; i++, offset += limited_size) { size_t current_size = std::min(size - offset, limited_size); - number_of_blocks = (current_size + items_per_block - 1)/items_per_block; - auto grid_size = (number_of_blocks + block_size - 1)/block_size; + number_of_blocks = (current_size + items_per_block - 1) / items_per_block; + auto grid_size = (number_of_blocks + block_size - 1) / block_size; if(debug_synchronous) { @@ -299,7 +304,8 @@ inline auto scan_impl(void* temporary_storage, number_of_blocks, start); - if(debug_synchronous) start = std::chrono::steady_clock::now(); + if(debug_synchronous) + start = std::chrono::steady_clock::now(); grid_size = number_of_blocks; if(debug_synchronous) @@ -317,6 +323,7 @@ inline auto scan_impl(void* temporary_storage, { lookback_scan_kernel 1); }); - ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("lookback_scan_kernel", current_size, start); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("lookback_scan_kernel", + current_size, + start); // Swap the last_elements if(number_of_launch > 1) @@ -346,7 +355,8 @@ inline auto scan_impl(void* temporary_storage, ::rocprim::identity(), stream, debug_synchronous); - if(error != hipSuccess) return error; + if(error != hipSuccess) + return error; } } } @@ -362,6 +372,7 @@ inline auto scan_impl(void* temporary_storage, } single_scan_kernel, where \p T is a \p value_type of \p InputIterator. /// \tparam AccType accumulator type used to propagate the scanned values. The default is the type that -/// is returned by a function of type BinaryFunction when it's is passed an InputIterator value. +/// is returned by a function of type BinaryFunction when it is passed an InputIterator value. /// /// \param [in] temporary_storage pointer to a device-accessible temporary storage. When /// a null pointer is passed, the required allocation size (in bytes) is written to @@ -495,7 +506,9 @@ template::value_type>, - class AccType = rocprim::invoke_result_binary_op_t::value_type, BinaryFunction>> + class AccType + = ::rocprim::accumulator_t::value_type>> inline hipError_t inclusive_scan(void* temporary_storage, size_t& storage_size, InputIterator input, @@ -507,6 +520,7 @@ inline hipError_t inclusive_scan(void* temporary_storage, { // input_type() is a dummy initial value (not used) return detail::scan_impl, where \p T is a \p value_type of \p InputIterator. +/// \tparam AccType accumulator type used to propagate the scanned values. Default type +/// is value type of the input iterator. +/// +/// \param [in] temporary_storage pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the scan operation. +/// \param [in,out] storage_size reference to a size (in bytes) of \p temporary_storage. +/// \param [in] input iterator to the first element in the range to scan. +/// \param [out] output iterator to the first element in the output range. It can be +/// same as \p input. +/// \param [in] initial_value initial value to start the scan. +/// A rocpim::future_value may be passed to use a value that will be later computed. +/// \param [in] size number of element in the input range. +/// \param [in] scan_op binary operation function object that will be used for scan. +/// The signature of the function should be equivalent to the following: +/// T f(const T &a, const T &b);. The signature does not need to have +/// const &, but function object must not modify the objects passed to it. +/// Default is BinaryFunction(). +/// \param [in] stream [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p hipSuccess (\p 0) after successful scan; otherwise a HIP runtime error of +/// type \p hipError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level inclusive sum operation is performed on an array of +/// integer values (shorts are scanned into ints). +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// size_t input_size; // e.g., 8 +/// short * input; // e.g., [1, 2, 3, 4, 5, 6, 7, 8] +/// int * output; // empty array of 8 elements +/// int initial_value; // e.g. 10 +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::inclusive_scan( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, initial_value, input_size, rocprim::plus() +/// ); +/// +/// // allocate temporary storage +/// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform scan +/// rocprim::inclusive_scan( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, initial_value, input_size, rocprim::plus() +/// ); +/// // output: [11, 13, 16, 20, 25, 31, 38, 46] +/// \endcode +/// +/// The same example as above, but now a custom accumulator type is specified. +/// +/// \code{.cpp} +/// #include +/// +/// size_t input_size; +/// short * input; +/// int * output; +/// int initial_value; +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// +/// rocprim::inclusive_scan( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, output, initial_value, input_size, rocprim::plus() +/// ); +/// +/// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // Use type parameter to set custom accumulator type +/// rocprim::inclusive_scan, +/// int>(temporary_storage_ptr, +/// temporary_storage_size_bytes, +/// input_iterator, +/// output, +/// initial_value +/// input_size, +/// rocprim::plus()); +/// \endcode +/// \endparblock +template::value_type>, + class AccType + = ::rocprim::accumulator_t::value_type>> +inline hipError_t inclusive_scan(void* temporary_storage, + size_t& storage_size, + InputIterator input, + OutputIterator output, + const InitValueType initial_value, + const size_t size, + BinaryFunction scan_op = BinaryFunction(), + const hipStream_t stream = 0, + bool debug_synchronous = false) +{ + // input_type() is a dummy initial value (not used) + return detail::scan_impl(temporary_storage, + storage_size, + input, + output, + initial_value, + size, + scan_op, + stream, + debug_synchronous); +} + /// \brief Bitwise-reproducible parallel inclusive scan primitive for device level. /// /// This function behaves the same as inclusive_scan(), except that unlike @@ -536,7 +709,9 @@ template::value_type>, - class AccType = rocprim::invoke_result_binary_op_t::value_type, BinaryFunction>> + class AccType + = ::rocprim::accumulator_t::value_type>> inline hipError_t deterministic_inclusive_scan(void* temporary_storage, size_t& storage_size, InputIterator input, @@ -547,6 +722,7 @@ inline hipError_t deterministic_inclusive_scan(void* temporary_stora bool debug_synchronous = false) { return detail::scan_implinclusive_scan(), except that unlike +/// inclusive_scan(), it provides run-to-run deterministic behavior for +/// non-associative scan operators like floating point arithmetic operations. +/// Refer to the documentation for \link inclusive_scan() rocprim::inclusive_scan \endlink +/// for a detailed description of this function. +template::value_type>, + class AccType + = ::rocprim::accumulator_t::value_type>> +inline hipError_t deterministic_inclusive_scan(void* temporary_storage, + size_t& storage_size, + InputIterator input, + OutputIterator output, + InitValueType initial_value, + const size_t size, + BinaryFunction scan_op = BinaryFunction(), + const hipStream_t stream = 0, + bool debug_synchronous = false) +{ + return detail::scan_impl(temporary_storage, + storage_size, + input, + output, + initial_value, + size, + scan_op, + stream, + debug_synchronous); +} + /// \brief Parallel exclusive scan primitive for device level. /// /// exclusive_scan function performs a device-wide exclusive prefix scan operation @@ -591,7 +812,8 @@ inline hipError_t deterministic_inclusive_scan(void* temporary_stora /// \tparam BinaryFunction type of binary function used for scan. Default type /// is \p rocprim::plus, where \p T is a \p value_type of \p InputIterator. /// \tparam AccType accumulator type used to propagate the scanned values. The default is the type that -/// is returned by a function of type BinaryFunction when it's is passed a value of type InitValueType. +/// is returned by a function of type BinaryFunction when it is passed a value of type \p InitValueType, +/// unless it's 'rocprim::future_value'. Then it will be the wrapped input type. /// /// \param [in] temporary_storage pointer to a device-accessible temporary storage. When /// a null pointer is passed, the required allocation size (in bytes) is written to @@ -661,7 +883,8 @@ template::value_type>, - class AccType = rocprim::invoke_result_binary_op_t, BinaryFunction>> + class AccType + = ::rocprim::accumulator_t>> inline hipError_t exclusive_scan(void* temporary_storage, size_t& storage_size, InputIterator input, @@ -673,6 +896,7 @@ inline hipError_t exclusive_scan(void* temporary_storage, bool debug_synchronous = false) { return detail::scan_impl::value_type>, - class AccType = rocprim::invoke_result_binary_op_t, BinaryFunction>> + class AccType + = ::rocprim::accumulator_t>> inline hipError_t deterministic_exclusive_scan(void* temporary_storage, size_t& storage_size, InputIterator input, @@ -715,6 +940,7 @@ inline hipError_t deterministic_exclusive_scan(void* temporary_sto bool debug_synchronous = false) { return detail::scan_impl -struct wrapped_search_n_config +template +struct wrapped_search_n_config { template struct architecture_config { - static constexpr search_n_config_params params = {8, kernel_config<256, 4>()}; + static constexpr search_n_config_params params + = default_search_n_config(Arch), Value>{}; }; }; diff --git a/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp b/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp index 4b385e5c5..4a98e4708 100644 --- a/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp +++ b/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp @@ -39,7 +39,6 @@ #include "../block/block_load.hpp" #include "../iterator/counting_iterator.hpp" #include "../iterator/reverse_iterator.hpp" -#include "../thread/radix_key_codec.hpp" #include "detail/device_segmented_radix_sort.hpp" #include "device_partition.hpp" #include "device_segmented_radix_sort_config.hpp" @@ -348,12 +347,12 @@ hipError_t segmented_radix_sort_impl(void * temporary_storage, return segment_length > max_small_segment_length; }; - const bool with_double_buffer = keys_tmp != nullptr; - const unsigned int bits = end_bit - begin_bit; - const unsigned int iterations = ::rocprim::detail::ceiling_div(bits, params.long_radix_bits); - const bool to_output = with_double_buffer || (iterations - 1) % 2 == 0; - is_result_in_output = (iterations % 2 == 0) != to_output; - const bool do_partitioning + const bool with_double_buffer = keys_tmp != nullptr; + const unsigned int bits = end_bit - begin_bit; + const unsigned int iterations = ::rocprim::detail::ceiling_div(bits, params.radix_bits); + const bool to_output = with_double_buffer || (iterations - 1) % 2 == 0; + is_result_in_output = (iterations % 2 == 0) != to_output; + const bool do_partitioning = partitioning_allowed && segments >= params.warp_sort_config.partitioning_threshold; const size_t medium_segment_indices_size = three_way_partitioning ? segments : 0; diff --git a/rocprim/include/rocprim/device/device_segmented_reduce.hpp b/rocprim/include/rocprim/device/device_segmented_reduce.hpp index d748cdae8..2216f8a83 100644 --- a/rocprim/include/rocprim/device/device_segmented_reduce.hpp +++ b/rocprim/include/rocprim/device/device_segmented_reduce.hpp @@ -85,8 +85,7 @@ hipError_t segmented_reduce_impl(void * temporary_storage, bool debug_synchronous) { using input_type = typename std::iterator_traits::value_type; - using result_type = - typename ::rocprim::invoke_result_binary_op::type; + using result_type = ::rocprim::accumulator_t; using config = wrapped_segmented_reduce_config; diff --git a/rocprim/include/rocprim/device/device_transform.hpp b/rocprim/include/rocprim/device/device_transform.hpp index 4028dfc13..0ac0c286d 100644 --- a/rocprim/include/rocprim/device/device_transform.hpp +++ b/rocprim/include/rocprim/device/device_transform.hpp @@ -21,19 +21,20 @@ #ifndef ROCPRIM_DEVICE_DEVICE_TRANSFORM_HPP_ #define ROCPRIM_DEVICE_DEVICE_TRANSFORM_HPP_ -#include -#include -#include -#include - -#include "../config.hpp" #include "../common.hpp" +#include "../config.hpp" #include "../detail/various.hpp" #include "../iterator/zip_iterator.hpp" #include "../types/tuple.hpp" -#include "device_transform_config.hpp" #include "detail/device_transform.hpp" +#include "device_transform_config.hpp" + +#include +#include +#include +#include +#include /// \addtogroup devicemodule /// @{ @@ -43,7 +44,8 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template().kernel_config.block_size) void transform_kernel( InputIterator input, const size_t size, OutputIterator output, UnaryFunction transform_op) { - transform_kernel_impl().kernel_config.block_size, + transform_kernel_impl().kernel_config.block_size, device_params().kernel_config.items_per_thread, + device_params().load_type, ResultType>(input, size, output, transform_op); } -} // end of detail namespace +template +inline hipError_t transform_impl(InputIterator input, + OutputIterator output, + const size_t size, + UnaryFunction transform_op, + const hipStream_t stream, + bool debug_synchronous) +{ + if(size == size_t(0)) + { + return hipSuccess; + } + + using input_type = typename std::iterator_traits::value_type; + using result_type = typename ::rocprim::invoke_result::type; + + using config = detail::wrapped_transform_config; + + detail::target_arch target_arch; + hipError_t result = detail::host_target_arch(stream, target_arch); + if(result != hipSuccess) + { + return result; + } + const detail::transform_config_params params + = detail::dispatch_target_arch(target_arch); + + const unsigned int block_size = params.kernel_config.block_size; + const unsigned int items_per_thread = params.kernel_config.items_per_thread; + const auto items_per_block = block_size * items_per_thread; + + // Start point for time measurements + std::chrono::steady_clock::time_point start; + + const auto size_limit = params.kernel_config.size_limit; + const auto number_of_blocks_limit = ::rocprim::max(size_limit / items_per_block, 1); + + auto number_of_blocks = (size + items_per_block - 1) / items_per_block; + if(debug_synchronous) + { + std::cout << "block_size " << block_size << '\n'; + std::cout << "number of blocks " << number_of_blocks << '\n'; + std::cout << "number of blocks limit " << number_of_blocks_limit << '\n'; + std::cout << "items_per_block " << items_per_block << '\n'; + } + + const auto aligned_size_limit = number_of_blocks_limit * items_per_block; + + // Launch number_of_blocks_limit blocks while there is still at least as many blocks left as the limit + const auto number_of_launch = (size + aligned_size_limit - 1) / aligned_size_limit; + for(size_t i = 0, offset = 0; i < number_of_launch; ++i, offset += aligned_size_limit) + { + const auto current_size = std::min(size - offset, aligned_size_limit); + const auto current_blocks = (current_size + items_per_block - 1) / items_per_block; + + if(debug_synchronous) + { + start = std::chrono::steady_clock::now(); + } + + detail::transform_kernel + <<>>(input + offset, + current_size, + output + offset, + transform_op); + + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("transform_kernel", current_size, start); + } + + return hipSuccess; +} + +} // namespace detail /// \brief Parallel transform primitive for device level. /// @@ -123,65 +203,16 @@ inline hipError_t transform(InputIterator input, const hipStream_t stream = 0, bool debug_synchronous = false) { - if( size == size_t(0) ) - return hipSuccess; - - using input_type = typename std::iterator_traits::value_type; - using result_type = typename ::rocprim::invoke_result::type; - - using config = detail::wrapped_transform_config; - - detail::target_arch target_arch; - hipError_t result = detail::host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } - const detail::transform_config_params params - = detail::dispatch_target_arch(target_arch); + constexpr bool is_pointer + = std::is_pointer::value && std::is_pointer::value; - const unsigned int block_size = params.kernel_config.block_size; - const unsigned int items_per_thread = params.kernel_config.items_per_thread; - const auto items_per_block = block_size * items_per_thread; - - // Start point for time measurements - std::chrono::steady_clock::time_point start; - - const auto size_limit = params.kernel_config.size_limit; - const auto number_of_blocks_limit = ::rocprim::max(size_limit / items_per_block, 1); - - auto number_of_blocks = (size + items_per_block - 1)/items_per_block; - if(debug_synchronous) - { - std::cout << "block_size " << block_size << '\n'; - std::cout << "number of blocks " << number_of_blocks << '\n'; - std::cout << "number of blocks limit " << number_of_blocks_limit << '\n'; - std::cout << "items_per_block " << items_per_block << '\n'; - } - - const auto aligned_size_limit = number_of_blocks_limit * items_per_block; - - // Launch number_of_blocks_limit blocks while there is still at least as many blocks left as the limit - const auto number_of_launch = (size + aligned_size_limit - 1) / aligned_size_limit; - for(size_t i = 0, offset = 0; i < number_of_launch; ++i, offset += aligned_size_limit) { - const auto current_size = std::min(size - offset, aligned_size_limit); - const auto current_blocks = (current_size + items_per_block - 1) / items_per_block; - - if(debug_synchronous) - start = std::chrono::steady_clock::now(); - hipLaunchKernelGGL(HIP_KERNEL_NAME(detail::transform_kernel), - dim3(current_blocks), - dim3(block_size), - 0, - stream, - input + offset, - current_size, - output + offset, - transform_op); - ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("transform_kernel", current_size, start); - } - - return hipSuccess; + return detail::transform_impl( + input, + output, + size, + transform_op, + stream, + debug_synchronous); } /// \brief Parallel device-level transform primitive for two inputs. @@ -241,32 +272,104 @@ inline hipError_t transform(InputIterator input, /// // output: [2, 4, 6, 8, 10, 12, 14, 16] /// \endcode /// \endparblock -template< - class Config = default_config, - class InputIterator1, - class InputIterator2, - class OutputIterator, - class BinaryFunction -> -inline -hipError_t transform(InputIterator1 input1, - InputIterator2 input2, - OutputIterator output, - const size_t size, - BinaryFunction transform_op, - const hipStream_t stream = 0, - bool debug_synchronous = false) +template +inline hipError_t transform(InputIterator1 input1, + InputIterator2 input2, + OutputIterator output, + const size_t size, + BinaryFunction transform_op, + const hipStream_t stream = 0, + bool debug_synchronous = false) { using value_type1 = typename std::iterator_traits::value_type; using value_type2 = typename std::iterator_traits::value_type; return transform( - ::rocprim::make_zip_iterator(::rocprim::make_tuple(input1, input2)), output, - size, detail::unpack_binary_op(transform_op), - stream, debug_synchronous - ); + ::rocprim::make_zip_iterator(::rocprim::make_tuple(input1, input2)), + output, + size, + detail::unpack_binary_op(transform_op), + stream, + debug_synchronous); } - +/// \brief Parallel device-level transform primitive for an arbitrary amount of inputs. +/// +/// transform function performs a device-wide transformation operation +/// on n input ranges using binary \p transform_op operator. +/// +/// \par Overview +/// * Ranges specified by \p output and all iterators in \p input_iters must have at least \p size elements. +/// +/// \tparam Config [optional] Configuration of the primitive, must be `default_config` or `transform_config`. +/// \tparam InputIterators all the random-access iterator types of the input range. These types must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam OutputIterator random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam BinaryFunction type of binary function used for transform. +/// +/// \param [in] input_iters a tuple of iterators to the input sequences where num_items elements are read from each. +/// \param [out] output iterator to the first element in the output range. +/// \param [in] size number of element in the input range. +/// \param [in] transform_op an n-ary function object used for the transform, where n is the number of input sequences. +/// The function object must not modify the object passed to it. +/// \param [in] stream [optional] HIP stream object. The default is \p 0 (default stream). +/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel +/// launch is forced. Default value is \p false. +/// +/// \par Example +/// \parblock +/// In this example a device-level transform operation is performed on three arrays of +/// integer values (element-wise sum is performed). +/// +/// \code{.cpp} +/// #include +/// +/// // custom transform function +/// auto transform_op = +/// [] __device__ (int a, int b, int c) -> int +/// { +/// return a + b + c; +/// }; +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// size_t size; // e.g., 8 +/// int* input1; // e.g., [1, 2, 3, 4, 5, 6, 7, 8] +/// int* input2; // e.g., [1, 2, 3, 4, 5, 6, 7, 8] +/// int* input3; // e.g., [1, 2, 3, 4, 5, 6, 7, 8] +/// int* output; // empty array of 8 elements +/// +/// // perform transform +/// rocprim::transform( +/// rocprim::tuple(input1, input2, input3), output, input1.size(), transform_op +/// ); +/// // output: [3, 6, 9, 12, 15, 18, 21, 24] +/// \endcode +/// \endparblock +template +inline hipError_t transform(rocprim::tuple input_iters, + OutputIterator output, + const size_t size, + TransformOp transform_op, + const hipStream_t stream = 0, + bool debug_synchronous = false) +{ + return transform( + ::rocprim::make_zip_iterator(input_iters), + output, + size, + detail::unpack_nary_op::value_type...>( + transform_op), + stream, + debug_synchronous); +} END_ROCPRIM_NAMESPACE diff --git a/rocprim/include/rocprim/device/device_transform_config.hpp b/rocprim/include/rocprim/device/device_transform_config.hpp index c3de13b48..a74bc163f 100644 --- a/rocprim/include/rocprim/device/device_transform_config.hpp +++ b/rocprim/include/rocprim/device/device_transform_config.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -24,10 +24,11 @@ #include #include "../config.hpp" -#include "../functional.hpp" #include "../detail/various.hpp" +#include "../functional.hpp" #include "detail/config/device_transform.hpp" +#include "detail/config/device_transform_pointer.hpp" #include "detail/device_config_helper.hpp" /// \addtogroup primitivesmodule_deviceconfigs @@ -38,7 +39,7 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template +template struct wrapped_transform_config { static_assert(std::is_base_of::value, @@ -52,7 +53,18 @@ struct wrapped_transform_config }; template -struct wrapped_transform_config +struct wrapped_transform_config +{ + template + struct architecture_config + { + static constexpr transform_config_params params + = default_transform_pointer_config(Arch), Value>{}; + }; +}; + +template +struct wrapped_transform_config { template struct architecture_config @@ -63,15 +75,20 @@ struct wrapped_transform_config }; #ifndef DOXYGEN_SHOULD_SKIP_THIS -template +template +template +constexpr transform_config_params + wrapped_transform_config::architecture_config::params; + +template template constexpr transform_config_params - wrapped_transform_config::architecture_config::params; + wrapped_transform_config::architecture_config::params; template template constexpr transform_config_params - wrapped_transform_config::architecture_config::params; + wrapped_transform_config::architecture_config::params; #endif // DOXYGEN_SHOULD_SKIP_THIS } // end namespace detail diff --git a/rocprim/include/rocprim/intrinsics/arch.hpp b/rocprim/include/rocprim/intrinsics/arch.hpp index c6220d9e1..9bca31c99 100644 --- a/rocprim/include/rocprim/intrinsics/arch.hpp +++ b/rocprim/include/rocprim/intrinsics/arch.hpp @@ -23,6 +23,9 @@ #include "../config.hpp" +#include +#include + BEGIN_ROCPRIM_NAMESPACE /// \brief Utilities to query architecture details. @@ -36,12 +39,11 @@ namespace wavefront /// \brief Return the number of threads in the wavefront. /// /// This function is not `constexpr`. - ROCPRIM_DEVICE ROCPRIM_INLINE -unsigned int size() +ROCPRIM_DEVICE ROCPRIM_INLINE +unsigned int size() noexcept { - // This function is **not** constexpr because it will - // be using '__builtin_amdgcn_wavefrontsize()'. - return ROCPRIM_WAVEFRONT_SIZE; + // Note: this function is **not** constexpr! + return __builtin_amdgcn_wavefrontsize(); } /// \brief Return the minimum number of threads in the wavefront. @@ -65,8 +67,11 @@ unsigned int size() ROCPRIM_HOST_DEVICE ROCPRIM_INLINE constexpr unsigned int min_size() { -#if __HIP_DEVICE_COMPILE__ - return ROCPRIM_WAVEFRONT_SIZE; +#if __HIP_DEVICE_COMPILE__ && !__SPIRV__ + #if ROCPRIM_NAVI + return 32u; + #endif + return 64u; #else return ROCPRIM_WARP_SIZE_32; #endif @@ -92,16 +97,203 @@ constexpr unsigned int min_size() ROCPRIM_HOST_DEVICE ROCPRIM_INLINE constexpr unsigned int max_size() { -#if __HIP_DEVICE_COMPILE__ - return ROCPRIM_WAVEFRONT_SIZE; +#if __HIP_DEVICE_COMPILE__ && !__SPIRV__ + return min_size(); #else return ROCPRIM_WARP_SIZE_64; #endif } + +/// \brief Enumeration of possible wavefront hardware targets. +enum class target +{ + /// Target hardware wavefront of size 32. + size32, + /// Target hardware wavefront of size 64. + size64, + /// Target hardware wavefront of unknown size. This is + /// the case when targeting SPIR-V. Use \p target::size32 + /// and \p target::size64 to target a specific hardware + /// wavefront size. + dynamic, +}; + +/// \brief Returns the hardware wavefront size of the current +/// compile target. +/// +/// On host this will return \p target::dynamic. On device +/// this return \p target::size32, \p target::size64, or +/// when targeting SPIR-V \p target::dynamic. +constexpr target get_target() noexcept +{ +#if !defined(__HIP_DEVICE_COMPILE__) || defined(__SPIRV__) + // SPIR-V and host both have unknown compile size. + return target::dynamic; +#else + // The wavefront size is exactly known. + static_assert(min_size() == max_size()); + + if constexpr(min_size() == ROCPRIM_WARP_SIZE_32) + { + return target::size32; + } + return target::size64; +#endif +} + +/// \brief Returns the numerical wavefront size from a +/// given \p rocprim::arch::wavefront::target. +/// +/// This function has no implementation for +/// \p target::dynamic. +template +constexpr unsigned int size_from_target() = delete; + +// Doxygen should ignore the specializations. +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template<> +constexpr unsigned int size_from_target() +{ + return ROCPRIM_WARP_SIZE_32; +} +template<> +constexpr unsigned int size_from_target() +{ + return ROCPRIM_WARP_SIZE_64; +} +#endif + }; // namespace wavefront } // namespace arch +namespace detail +{ + +/// \brief Utility to quickly enable specialization for dynamic +/// wavefront targets. +template<::rocprim::arch::wavefront::target Target> +using wave_target_guard_t = std::enable_if_t; + +template +struct dispatch_wave_size +{ + union storage_type + { + typename Impl32::storage_type wave32; + typename Impl64::storage_type wave64; + }; + + template + ROCPRIM_HOST_DEVICE + auto operator()(F exec, Args&&... args) + { + // Select either the wave32 or wave64 implementation. + auto select = [&](auto impl) -> decltype(auto) + { + // Given an implementation, execute our callback and + // pass our re-mapped arguments. The re-mapping selects + // the appropiate backing 'storage_type' for the chosen + // implementation. + return exec( + // Pass the selected implementation to the callback. + impl, + // Map over every argument in the varadic packing... + [](auto&& arg) -> decltype(auto) + { + // If the argument is 'storage_type'... + if constexpr(std::is_same_v< + // std::remove_cvref is C++20 + std::remove_cv_t>, + storage_type>) + { // And we have a wave32 implementation... + if constexpr(std::is_same_v) + { // We return the wave32 backing storage! + return std::forward(arg.wave32); + } + else + { // And otherwise the wave64 backing storage! + return std::forward(arg.wave64); + } + } + else + { // Otherwise, pass argument transparently. + return std::forward(arg); + } + }(args)...); + }; + + // Now do the actual implementation selection. The compiler + // *should* optimize this after lowering, but the extra + // allocated shared memory due to union is unrecoverable. + if(::rocprim::arch::wavefront::size() == ROCPRIM_WARP_SIZE_64) + { + return select(Impl64{}); + } + else + { + return select(Impl32{}); + } + } +}; + +/// \brief Utility function to assert the wavefront size. +/// +/// Assertion is done either at runtime if we are curenntly +/// compiling for SPIR-V, or if the target is dynamic. +/// Otherwise, we use a static assert. +template<::rocprim::arch::wavefront::target Target> +struct check_wave_size +{ + /// \brief The assertion to do. + template + ROCPRIM_INLINE ROCPRIM_HOST_DEVICE + constexpr void + operator()(P predicate) const + { +#if !defined(__HIP_DEVICE_COMPILE__) || ROCPRIM_TARGET_SPIRV + // When a dynamic wavefront size specializes, we actually + // don't know if the type is valid or not. + assert(predicate(::rocprim::arch::wavefront::size())); +#else + // If we are on device, we do want to statically assert, if possible! + static_assert(predicate(::rocprim::arch::wavefront::size_from_target())); +#endif + // On release builds, assert is no-op, so it will complain + // about unused parameters... + (void)predicate; + } +}; + +/// \brief Short alias to check if the virtual wavefront size fits on +/// the current or specified target. +template +ROCPRIM_INLINE ROCPRIM_HOST_DEVICE +void check_virtual_wave_size() +{ + check_wave_size{}([](unsigned int size) constexpr { return VirtualWaveSize <= size; }); +} + +template<> +struct check_wave_size<::rocprim::arch::wavefront::target::dynamic> +{ + template + ROCPRIM_INLINE ROCPRIM_HOST_DEVICE + void operator()(P predicate) const + { + // Since we don't know the wavefront size, we have to + // do a runtime query. + assert(predicate(::rocprim::arch::wavefront::size())); + + // On release builds, assert is no-op, so it will complain + // about unused parameters... + (void)predicate; + } +}; + +} // namespace detail + END_ROCPRIM_NAMESPACE #endif diff --git a/rocprim/include/rocprim/intrinsics/atomic.hpp b/rocprim/include/rocprim/intrinsics/atomic.hpp index f614aa856..ba6dc6a0b 100644 --- a/rocprim/include/rocprim/intrinsics/atomic.hpp +++ b/rocprim/include/rocprim/intrinsics/atomic.hpp @@ -170,7 +170,7 @@ namespace detail #define ROCPRIM_ATOMIC_LOAD(inst, mod, wait, ptr) \ asm volatile(inst " %0, %1 " mod "\t\n" wait : "=v"(result) : "v"(ptr) : "memory") -#if ROCPRIM_TARGET_CDNA4 || ROCPRIM_TARGET_CDNA3 +#if ROCPRIM_TARGET_CDNA3 #define ROCPRIM_ATOMIC_LOAD_FLAT(ptr) \ ROCPRIM_ATOMIC_LOAD("flat_load_dwordx4", "sc1", "s_waitcnt vmcnt(0)", ptr) #define ROCPRIM_ATOMIC_LOAD_SHARED(ptr) \ @@ -198,7 +198,10 @@ namespace detail ROCPRIM_ATOMIC_LOAD("ds_read_b128", "", "s_waitcnt lgkmcnt(0)", ptr) // This architecture doesn't support atomics on the global AS. #define ROCPRIM_ATOMIC_LOAD_GLOBAL(ptr) ROCPRIM_ATOMIC_LOAD_FLAT(ptr) -#elif ROCPRIM_TARGET_RDNA3 || ROCPRIM_TARGET_CDNA2 || ROCPRIM_TARGET_CDNA1 || ROCPRIM_TARGET_GCN5 +#elif ROCPRIM_TARGET_RDNA3 || ROCPRIM_TARGET_CDNA2 || ROCPRIM_TARGET_CDNA1 || ROCPRIM_TARGET_GCN5 \ + || ROCPRIM_TARGET_SPIRV + // We don't really know what architecture we are on when targeting + // SPIR-V. Lets just assume it's one of these. #define ROCPRIM_ATOMIC_LOAD_FLAT(ptr) \ ROCPRIM_ATOMIC_LOAD("flat_load_dwordx4", "glc", "s_waitcnt vmcnt(0)", ptr) #define ROCPRIM_ATOMIC_LOAD_SHARED(ptr) \ @@ -211,8 +214,8 @@ namespace detail #endif #ifdef __HIP_DEVICE_COMPILE__ - #if defined(__has_builtin) && __has_builtin(__builtin_amdgcn_is_shared) \ - && __has_builtin(__builtin_amdgcn_is_private) + #if !ROCPRIM_TARGET_SPIRV && defined(__has_builtin) \ + && __has_builtin(__builtin_amdgcn_is_shared) && __has_builtin(__builtin_amdgcn_is_private) auto* ptr = (const __attribute__((address_space(0 /*flat*/))) __uint128_t*)address; if(__builtin_amdgcn_is_shared(ptr)) @@ -232,6 +235,8 @@ namespace detail ROCPRIM_ATOMIC_LOAD_GLOBAL(global_ptr); } #else + // SPIR-V does not like the address-space checks. For now + // lets just do flat loading/storing. ROCPRIM_ATOMIC_LOAD_FLAT(address); #endif #else @@ -280,7 +285,7 @@ namespace detail #define ROCPRIM_ATOMIC_STORE(inst, mod, wait, ptr) \ asm volatile(inst " %0, %1 " mod "\t\n" wait : : "v"(ptr), "v"(value) : "memory") -#if ROCPRIM_TARGET_CDNA4 || ROCPRIM_TARGET_CDNA3 +#if ROCPRIM_TARGET_CDNA3 #define ROCPRIM_ATOMIC_STORE_FLAT(ptr) \ ROCPRIM_ATOMIC_STORE("flat_store_dwordx4", "sc1", "s_waitcnt vmcnt(0)", ptr) #define ROCPRIM_ATOMIC_STORE_SHARED(ptr) \ @@ -302,7 +307,9 @@ namespace detail // This architecture doesn't support atomics on the global AS. #define ROCPRIM_ATOMIC_STORE_GLOBAL(ptr) ROCPRIM_ATOMIC_STORE_FLAT(ptr) #elif ROCPRIM_TARGET_RDNA3 || ROCPRIM_TARGET_RDNA2 || ROCPRIM_TARGET_RDNA1 || ROCPRIM_TARGET_CDNA2 \ - || ROCPRIM_TARGET_CDNA1 || ROCPRIM_TARGET_GCN5 + || ROCPRIM_TARGET_CDNA1 || ROCPRIM_TARGET_GCN5 || ROCPRIM_TARGET_SPIRV + // We don't really know what architecture we are on when targeting + // SPIR-V. Lets just assume it's one of these. #define ROCPRIM_ATOMIC_STORE_FLAT(ptr) \ ROCPRIM_ATOMIC_STORE("flat_store_dwordx4", "", "s_waitcnt vmcnt(0)", ptr) #define ROCPRIM_ATOMIC_STORE_SHARED(ptr) \ @@ -315,8 +322,8 @@ namespace detail #endif #ifdef __HIP_DEVICE_COMPILE__ - #if defined(__has_builtin) && __has_builtin(__builtin_amdgcn_is_shared) \ - && __has_builtin(__builtin_amdgcn_is_private) + #if !ROCPRIM_TARGET_SPIRV && defined(__has_builtin) \ + && __has_builtin(__builtin_amdgcn_is_shared) && __has_builtin(__builtin_amdgcn_is_private) auto* ptr = (__attribute__((address_space(0 /*flat*/))) __uint128_t*)address; if(__builtin_amdgcn_is_shared(ptr)) @@ -334,6 +341,8 @@ namespace detail ROCPRIM_ATOMIC_STORE_GLOBAL(global_ptr); } #else + // SPIR-V does not like the address-space checks. For now + // lets just do flat loading/storing. ROCPRIM_ATOMIC_STORE_FLAT(address); #endif #else diff --git a/rocprim/include/rocprim/intrinsics/thread.hpp b/rocprim/include/rocprim/intrinsics/thread.hpp index 031232531..99f7a385d 100644 --- a/rocprim/include/rocprim/intrinsics/thread.hpp +++ b/rocprim/include/rocprim/intrinsics/thread.hpp @@ -34,33 +34,6 @@ BEGIN_ROCPRIM_NAMESPACE // Sizes -/// \brief Returns a number of threads in a hardware warp. -/// -/// It is constant for a device. -/// -/// \warning This function will be removed in a future release. -[[deprecated( - "Use the functions provided in 'rocprim::arch::wavefront' instead.")]] -ROCPRIM_HOST_DEVICE -inline constexpr unsigned int warp_size() -{ - return ROCPRIM_WAVEFRONT_SIZE; -} - -/// \brief Returns a number of threads in a hardware warp for the actual target. -/// At device side this constant is available at compile time. -/// -/// It is constant for a device. -/// -/// \warning This function will be removed in a future release. -[[deprecated("Use the functions provided in 'rocprim::arch::wavefront' " - "instead.")]] -ROCPRIM_DEVICE ROCPRIM_INLINE -constexpr unsigned int device_warp_size() -{ - return ROCPRIM_WAVEFRONT_SIZE; -} - /// \brief Returns flat size of a multidimensional block (tile). ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int flat_block_size() @@ -305,13 +278,6 @@ namespace detail return lane_id()%LogicalWarpSize; } - template<> - ROCPRIM_DEVICE ROCPRIM_INLINE - unsigned int logical_lane_id() - { - return lane_id(); - } - // Return id of "logical warp" in a block template ROCPRIM_DEVICE ROCPRIM_INLINE @@ -320,13 +286,6 @@ namespace detail return flat_block_thread_id()/LogicalWarpSize; } - template<> - ROCPRIM_DEVICE ROCPRIM_INLINE - unsigned int logical_warp_id() - { - return warp_id(); - } - ROCPRIM_DEVICE ROCPRIM_INLINE void memory_fence_system() { diff --git a/rocprim/include/rocprim/intrinsics/warp.hpp b/rocprim/include/rocprim/intrinsics/warp.hpp index afa15511b..5b80f008e 100644 --- a/rocprim/include/rocprim/intrinsics/warp.hpp +++ b/rocprim/include/rocprim/intrinsics/warp.hpp @@ -21,9 +21,13 @@ #ifndef ROCPRIM_INTRINSICS_WARP_HPP_ #define ROCPRIM_INTRINSICS_WARP_HPP_ +#include "arch.hpp" + #include "../config.hpp" #include "../types.hpp" +#include + #include BEGIN_ROCPRIM_NAMESPACE @@ -50,12 +54,17 @@ ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int masked_bit_count(lane_mask_type x, unsigned int add = 0) { int c; -#if ROCPRIM_WAVEFRONT_SIZE == 32 - c = ::__builtin_amdgcn_mbcnt_lo(x, add); -#else - c = ::__builtin_amdgcn_mbcnt_lo(static_cast(x), add); - c = ::__builtin_amdgcn_mbcnt_hi(static_cast(x >> 32), c); -#endif + c = ::__builtin_amdgcn_mbcnt_lo(static_cast(x), add); + if constexpr(sizeof(lane_mask_type) == 8) + { + // SPIR-V: We assumed 64 threads per wave, but this might not + // be correct. Do an extra check to only do the upper half, when + // there actually is an upper half. + if(::rocprim::arch::wavefront::size() == ROCPRIM_WARP_SIZE_64) + { + c = ::__builtin_amdgcn_mbcnt_hi(static_cast(x >> 32), c); + } + } return c; } @@ -74,7 +83,7 @@ int warp_all(int predicate) return ::__all(predicate); } -} // end detail namespace +} // namespace detail /// \overload /// \brief Group active lanes having the same bits of \p label @@ -93,9 +102,8 @@ int warp_all(int predicate) /// lane i's result includes bit j in the lane mask if lane j is part /// of the same group as lane i, i.e. lane i and j called with the /// same value for label. -ROCPRIM_DEVICE ROCPRIM_INLINE lane_mask_type match_any(unsigned int label, - unsigned int label_bits, - bool valid = true) +ROCPRIM_DEVICE ROCPRIM_INLINE +lane_mask_type match_any(unsigned int label, unsigned int label_bits, bool valid = true) { // Obtain a mask with the threads which are currently active. lane_mask_type peer_mask = ballot(valid); @@ -144,7 +152,8 @@ ROCPRIM_DEVICE ROCPRIM_INLINE lane_mask_type match_any(unsigned int label, /// same value for label. template -ROCPRIM_DEVICE ROCPRIM_INLINE lane_mask_type match_any(unsigned int label, bool valid = true) +ROCPRIM_DEVICE ROCPRIM_INLINE +lane_mask_type match_any(unsigned int label, bool valid = true) { // Dispatch to runtime version return match_any(label, LabelBits, valid); @@ -161,7 +170,8 @@ ROCPRIM_DEVICE ROCPRIM_INLINE lane_mask_type match_any(unsigned int label, bool /// /// \pre The relation specified by \p mask must be symmetric and transitive, in other words: the groups /// should be consistent between threads. -ROCPRIM_DEVICE ROCPRIM_INLINE bool group_elect(lane_mask_type mask) +ROCPRIM_DEVICE ROCPRIM_INLINE +bool group_elect(lane_mask_type mask) { const unsigned int prev_same_count = ::rocprim::masked_bit_count(mask); return prev_same_count == 0 && mask != 0; diff --git a/rocprim/include/rocprim/intrinsics/warp_shuffle.hpp b/rocprim/include/rocprim/intrinsics/warp_shuffle.hpp index f7b79cfd0..7c2a9e8fd 100644 --- a/rocprim/include/rocprim/intrinsics/warp_shuffle.hpp +++ b/rocprim/include/rocprim/intrinsics/warp_shuffle.hpp @@ -21,10 +21,15 @@ #ifndef ROCPRIM_INTRINSICS_WARP_SHUFFLE_HPP_ #define ROCPRIM_INTRINSICS_WARP_SHUFFLE_HPP_ +#include +#include #include #include "../config.hpp" #include "../detail/various.hpp" +#include "../intrinsics/bit.hpp" +#include "../intrinsics/warp.hpp" +#include "../types.hpp" #include "thread.hpp" /// \addtogroup warpmodule @@ -34,7 +39,6 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - template ROCPRIM_DEVICE ROCPRIM_INLINE typename std::enable_if::value && (sizeof(T) % sizeof(int) == 0), T>::type diff --git a/rocprim/include/rocprim/iterator/arg_index_iterator.hpp b/rocprim/include/rocprim/iterator/arg_index_iterator.hpp index 3adc89265..fe07bec79 100644 --- a/rocprim/include/rocprim/iterator/arg_index_iterator.hpp +++ b/rocprim/include/rocprim/iterator/arg_index_iterator.hpp @@ -207,12 +207,6 @@ class arg_index_iterator { offset_ = 0; } - - [[deprecated]] friend std::ostream& operator<<(std::ostream& os, - const arg_index_iterator& /* iter */) - { - return os; - } #endif // DOXYGEN_SHOULD_SKIP_THIS private: diff --git a/rocprim/include/rocprim/iterator/constant_iterator.hpp b/rocprim/include/rocprim/iterator/constant_iterator.hpp index d26d824d1..fe69be8cd 100644 --- a/rocprim/include/rocprim/iterator/constant_iterator.hpp +++ b/rocprim/include/rocprim/iterator/constant_iterator.hpp @@ -203,12 +203,6 @@ class constant_iterator { return distance_to(other) <= 0; } - - [[deprecated]] friend std::ostream& operator<<(std::ostream& os, const constant_iterator& iter) - { - os << "[" << iter.value_ << "]"; - return os; - } #endif // DOXYGEN_SHOULD_SKIP_THIS private: diff --git a/rocprim/include/rocprim/iterator/counting_iterator.hpp b/rocprim/include/rocprim/iterator/counting_iterator.hpp index a89ed491e..3208eb5f5 100644 --- a/rocprim/include/rocprim/iterator/counting_iterator.hpp +++ b/rocprim/include/rocprim/iterator/counting_iterator.hpp @@ -206,12 +206,6 @@ class counting_iterator { return distance_to(other) <= 0; } - - [[deprecated]] friend std::ostream& operator<<(std::ostream& os, const counting_iterator& iter) - { - os << "[" << iter.value_ << "]"; - return os; - } #endif // DOXYGEN_SHOULD_SKIP_THIS private: diff --git a/rocprim/include/rocprim/iterator/discard_iterator.hpp b/rocprim/include/rocprim/iterator/discard_iterator.hpp index dab938d27..eb2c8c8be 100644 --- a/rocprim/include/rocprim/iterator/discard_iterator.hpp +++ b/rocprim/include/rocprim/iterator/discard_iterator.hpp @@ -201,12 +201,6 @@ class discard_iterator { return index_ >= other.index_; } - - [[deprecated]] friend std::ostream& operator<<(std::ostream& os, - const discard_iterator& /* iter */) - { - return os; - } #endif // DOXYGEN_SHOULD_SKIP_THIS private: diff --git a/rocprim/include/rocprim/iterator/reverse_iterator.hpp b/rocprim/include/rocprim/iterator/reverse_iterator.hpp index beaaaf796..a016feaa6 100644 --- a/rocprim/include/rocprim/iterator/reverse_iterator.hpp +++ b/rocprim/include/rocprim/iterator/reverse_iterator.hpp @@ -74,10 +74,7 @@ class reverse_iterator {} /// \brief Constructs a new reverse_iterator using the supplied source. - [[deprecated("The initialisation constructor of 'rocprim::reverse_iterator' will be " - "marked explicit in ROCm 7.0. Use 'rocprim::make_reverse_iterator' " - "instead.")]] ROCPRIM_HOST_DEVICE constexpr /*explicit*/ - reverse_iterator(SourceIterator source_iterator) + ROCPRIM_HOST_DEVICE constexpr explicit reverse_iterator(SourceIterator source_iterator) : source_iterator_(source_iterator) {} @@ -249,10 +246,7 @@ template ROCPRIM_HOST_DEVICE constexpr reverse_iterator make_reverse_iterator(SourceIterator source_iterator) { -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wdeprecated-declarations" return reverse_iterator(source_iterator); -#pragma clang diagnostic pop } END_ROCPRIM_NAMESPACE diff --git a/rocprim/include/rocprim/iterator/texture_cache_iterator.hpp b/rocprim/include/rocprim/iterator/texture_cache_iterator.hpp index c8b1af5d7..ba61bae45 100644 --- a/rocprim/include/rocprim/iterator/texture_cache_iterator.hpp +++ b/rocprim/include/rocprim/iterator/texture_cache_iterator.hpp @@ -326,12 +326,6 @@ class texture_cache_iterator { return (ptr - other.ptr) <= 0; } - - [[deprecated]] friend std::ostream& operator<<(std::ostream& os, - const texture_cache_iterator& /* iter */) - { - return os; - } #endif // DOXYGEN_SHOULD_SKIP_THIS private: diff --git a/rocprim/include/rocprim/iterator/transform_iterator.hpp b/rocprim/include/rocprim/iterator/transform_iterator.hpp index 6f321bfcc..256d6eaf0 100644 --- a/rocprim/include/rocprim/iterator/transform_iterator.hpp +++ b/rocprim/include/rocprim/iterator/transform_iterator.hpp @@ -205,12 +205,6 @@ class transform_iterator { return iterator_ >= other.iterator_; } - - [[deprecated]] friend std::ostream& operator<<(std::ostream& os, - const transform_iterator& /* iter */) - { - return os; - } #endif // DOXYGEN_SHOULD_SKIP_THIS private: diff --git a/rocprim/include/rocprim/iterator/zip_iterator.hpp b/rocprim/include/rocprim/iterator/zip_iterator.hpp index f0a4cfd09..fb0e22b3f 100644 --- a/rocprim/include/rocprim/iterator/zip_iterator.hpp +++ b/rocprim/include/rocprim/iterator/zip_iterator.hpp @@ -282,11 +282,6 @@ class zip_iterator { return !(*this < other); } - - [[deprecated]] friend std::ostream& operator<<(std::ostream& os, const zip_iterator& /* iter */) - { - return os; - } #endif // DOXYGEN_SHOULD_SKIP_THIS private: diff --git a/rocprim/include/rocprim/rocprim.hpp b/rocprim/include/rocprim/rocprim.hpp index 37f1571a5..6cae29331 100644 --- a/rocprim/include/rocprim/rocprim.hpp +++ b/rocprim/include/rocprim/rocprim.hpp @@ -34,7 +34,6 @@ #include "type_traits.hpp" #include "iterator.hpp" -#include "thread/radix_key_codec.hpp" #include "thread/thread_load.hpp" #include "thread/thread_operators.hpp" #include "thread/thread_reduce.hpp" diff --git a/rocprim/include/rocprim/thread/radix_key_codec.hpp b/rocprim/include/rocprim/thread/radix_key_codec.hpp index 81ebe0758..b72860c3f 100644 --- a/rocprim/include/rocprim/thread/radix_key_codec.hpp +++ b/rocprim/include/rocprim/thread/radix_key_codec.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -32,638 +32,8 @@ #include "../types.hpp" #include "../types/tuple.hpp" -/// \addtogroup threadmodule -/// @{ - -BEGIN_ROCPRIM_NAMESPACE - -namespace detail -{ - -// Encode and decode integral and floating point values for radix sort in such a way that preserves -// correct order of negative and positive keys (i.e. negative keys go before positive ones, -// which is not true for a simple reinterpetation of the key's bits). - -// Digit extractor takes into account that (+0.0 == -0.0) is true for floats, -// so both +0.0 and -0.0 are reflected into the same bit pattern for digit extraction. -// Maximum digit length is 32. - -template -struct radix_key_codec_integral -{}; - -template -struct radix_key_codec_integral::value>::type> -{ - using bit_key_type = BitKey; - - ROCPRIM_HOST_DEVICE static bit_key_type encode(Key key) - { - return ::rocprim::detail::bit_cast(key); - } - - ROCPRIM_HOST_DEVICE static Key decode(bit_key_type bit_key) - { - return ::rocprim::detail::bit_cast(bit_key); - } - - template - ROCPRIM_HOST_DEVICE static unsigned int - extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length) - { - unsigned int mask = (1u << length) - 1; - return static_cast(bit_key >> start) & mask; - } -}; - -template -struct radix_key_codec_integral::value>::type> -{ - using bit_key_type = BitKey; - - static constexpr bit_key_type sign_bit = bit_key_type(1) << (sizeof(bit_key_type) * 8 - 1); - - ROCPRIM_HOST_DEVICE static bit_key_type encode(Key key) - { - const auto bit_key = ::rocprim::detail::bit_cast(key); - return sign_bit ^ bit_key; - } - - ROCPRIM_HOST_DEVICE static Key decode(bit_key_type bit_key) - { - bit_key ^= sign_bit; - return ::rocprim::detail::bit_cast(bit_key); - } - - template - ROCPRIM_HOST_DEVICE static unsigned int - extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length) - { - unsigned int mask = (1u << length) - 1; - return static_cast(bit_key >> start) & mask; - } -}; - -template -struct radix_key_codec_floating -{ - using bit_key_type = BitKey; - - static constexpr bit_key_type sign_bit - = ::rocprim::traits::get().float_bit_mask().sign_bit; - - ROCPRIM_HOST_DEVICE ROCPRIM_INLINE - static bit_key_type encode(Key key) - { - bit_key_type bit_key = ::rocprim::detail::bit_cast(key); - bit_key ^= (sign_bit & bit_key) == 0 ? sign_bit : bit_key_type(-1); - return bit_key; - } - - ROCPRIM_HOST_DEVICE ROCPRIM_INLINE static Key decode(bit_key_type bit_key) - { - bit_key ^= (sign_bit & bit_key) == 0 ? bit_key_type(-1) : sign_bit; - return ::rocprim::detail::bit_cast(bit_key); - } - - template - ROCPRIM_HOST_DEVICE static unsigned int - extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length) - { - unsigned int mask = (1u << length) - 1; - - // radix_key_codec_floating::encode() maps 0.0 to 0x8000'0000, - // and -0.0 to 0x7FFF'FFFF. - // radix_key_codec::encode() then flips the bits if descending, yielding: - // value | descending | ascending | - // ----- | ----------- | ----------- | - // 0.0 | 0x7FFF'FFFF | 0x8000'0000 | - // -0.0 | 0x8000'0000 | 0x7FFF'FFFF | - // - // For ascending sort, both should be mapped to 0x8000'0000, - // and for descending sort, both should be mapped to 0x7FFF'FFFF. - if ROCPRIM_IF_CONSTEXPR(Descending) - { - bit_key = bit_key == sign_bit ? static_cast(~sign_bit) : bit_key; - } - else - { - bit_key = bit_key == static_cast(~sign_bit) ? sign_bit : bit_key; - } - return static_cast(bit_key >> start) & mask; - } -}; - -template -struct radix_key_codec_base -{ - // Non-fundamental keys (custom keys) will not use any specialization and thus they do not - // have any of the struct members that fundamental types have. -}; - -template -struct radix_key_codec_base::value>::type> - : radix_key_codec_integral::type> -{}; - -template<> -struct radix_key_codec_base -{ - using bit_key_type = unsigned char; - - ROCPRIM_HOST_DEVICE static bit_key_type encode(bool key) - { - return static_cast(key); - } - - ROCPRIM_HOST_DEVICE static bool decode(bit_key_type bit_key) - { - return static_cast(bit_key); - } - - template - ROCPRIM_HOST_DEVICE static unsigned int - extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length) - { - unsigned int mask = (1u << length) - 1; - return static_cast(bit_key >> start) & mask; - } -}; - -template<> -struct radix_key_codec_base<::rocprim::half> - : radix_key_codec_floating<::rocprim::half, unsigned short> -{}; - -template<> -struct radix_key_codec_base<::rocprim::bfloat16> - : radix_key_codec_floating<::rocprim::bfloat16, unsigned short> -{}; - -template<> -struct radix_key_codec_base : radix_key_codec_floating -{}; - -template<> -struct radix_key_codec_base : radix_key_codec_floating -{}; - -template -struct has_bit_key_type -{ - template - static std::true_type check(typename U::bit_key_type*); - - template - static std::false_type check(...); - - using result = decltype(check(nullptr)); -}; - -template -using radix_key_fundamental = typename has_bit_key_type>::result; - -static_assert(radix_key_fundamental::value, "'int' should be fundamental"); -static_assert(!radix_key_fundamental::value, "'int*' should not be fundamental"); -static_assert(radix_key_fundamental::value, - "'rocprim::int128_t' should be fundamental"); -static_assert(radix_key_fundamental::value, - "'rocprim::uint128_t' should be fundamental"); -static_assert(!radix_key_fundamental::value, - "'rocprim::int128_t*' should not be fundamental"); - -} // namespace detail - -/// \brief Key encoder, decoder and bit-extractor for radix-based sorts. -/// -/// \tparam Key Type of the key used. -/// \tparam Descending Whether the sort is increasing or decreasing. -template::value> -class radix_key_codec : protected ::rocprim::detail::radix_key_codec_base -{ - using base_type = ::rocprim::detail::radix_key_codec_base; - -public: - /// \brief Type of the encoded key. - using bit_key_type = typename base_type::bit_key_type; - - /// \brief Encodes a key of type \p Key into \p bit_key_type. - /// - /// \tparam Decomposer Being \p Key a fundamental type, \p Decomposer should be - /// \p identity_decomposer. This is also the type by default. - /// \param [in] key Key to encode. - /// \param [in] decomposer [optional] Decomposer functor. - /// \return A \p bit_key_type encoded key. - template - ROCPRIM_HOST_DEVICE static bit_key_type encode(Key key, Decomposer decomposer = {}) - { - static_assert(std::is_same::value, - "Fundamental types don't use custom decomposer."); - bit_key_type bit_key = base_type::encode(key); - return Descending ? ~bit_key : bit_key; - } - - /// \brief Encodes in-place a key of type \p Key. - /// - /// \tparam Decomposer Being \p Key a fundamental type, \p Decomposer should be - /// \p identity_decomposer. This is also the type by default. - /// \param [in, out] key Key to encode. - /// \param [in] decomposer [optional] Decomposer functor. - template - ROCPRIM_HOST_DEVICE static void encode_inplace(Key& key, Decomposer decomposer = {}) - { - static_assert(std::is_same::value, - "Fundamental types don't use custom decomposer."); - key = ::rocprim::detail::bit_cast(encode(key)); - } - - /// \brief Decodes an encoded key of type \p bit_key_type back into \p Key. - /// - /// \tparam Decomposer Being \p Key a fundamental type, \p Decomposer should be - /// \p identity_decomposer. This is also the type by default. - /// \param [in] bit_key Key to decode. - /// \param [in] decomposer [optional] Decomposer functor. - /// \return A \p Key decoded key. - template - ROCPRIM_HOST_DEVICE static Key decode(bit_key_type bit_key, Decomposer decomposer = {}) - { - static_assert(std::is_same::value, - "Fundamental types don't use custom decomposer."); - bit_key = Descending ? ~bit_key : bit_key; - return base_type::decode(bit_key); - } - - /// \brief Decodes in-place an encoded key of type \p Key. - /// - /// \tparam Decomposer Being \p Key a fundamental type, \p Decomposer should be - /// \p identity_decomposer. This is also the type by default. - /// \param [in, out] key Key to decode. - /// \param [in] decomposer [optional] Decomposer functor. - template - ROCPRIM_HOST_DEVICE static void decode_inplace(Key& key, Decomposer decomposer = {}) - { - static_assert(std::is_same::value, - "Fundamental types don't use custom decomposer."); - key = decode(::rocprim::detail::bit_cast(key)); - } - - /// \brief Extracts the specified bits from a given encoded key. - /// - /// \param [in] bit_key Encoded key. - /// \param [in] start Start bit of the sequence of bits to extract. - /// \param [in] radix_bits How many bits to extract. - /// \return Requested bits from the key. - ROCPRIM_HOST_DEVICE static unsigned int - extract_digit(bit_key_type bit_key, unsigned int start, unsigned int radix_bits) - { - return base_type::template extract_digit(bit_key, start, radix_bits); - } - - /// \brief Extracts the specified bits from a given in-place encoded key. - /// - /// \tparam Decomposer Being \p Key a fundamental type, \p Decomposer should be - /// \p identity_decomposer. This is also the type by default. - /// \param [in] key Key. - /// \param [in] start Start bit of the sequence of bits to extract. - /// \param [in] radix_bits How many bits to extract. - /// \param [in] decomposer [optional] Decomposer functor. - /// \return Requested bits from the key. - template - ROCPRIM_HOST_DEVICE static unsigned int extract_digit(Key key, - unsigned int start, - unsigned int radix_bits, - Decomposer decomposer = {}) - { - static_assert(std::is_same::value, - "Fundamental types don't use custom decomposer."); - return extract_digit(::rocprim::detail::bit_cast(key), start, radix_bits); - } - - /// \brief Gives the default value for out-of-bound keys of type \p Key. - /// - /// \tparam Decomposer Being \p Key a fundamental type, \p Decomposer should be - /// \p identity_decomposer. This is also the type by default. - /// \param [in] decomposer [optional] Decomposer functor. - /// \return Out-of-bound keys' value. - template - ROCPRIM_HOST_DEVICE static Key get_out_of_bounds_key(Decomposer decomposer = {}) - { - static_assert(std::is_same::value, - "Fundamental types don't use custom decomposer."); - return decode(static_cast(-1)); - } -}; - -#ifndef DOXYGEN_SHOULD_SKIP_THIS // skip specializations -template -class radix_key_codec : protected detail::radix_key_codec_base -{ - using base_type = detail::radix_key_codec_base; - -public: - using bit_key_type = typename base_type::bit_key_type; - - template - ROCPRIM_HOST_DEVICE static bit_key_type encode(bool key, Decomposer decomposer = {}) - { - static_assert(std::is_same::value, - "Fundamental types don't use custom decomposer."); - return Descending != key; - } - - template - ROCPRIM_HOST_DEVICE static void encode_inplace(bool& key, Decomposer decomposer = {}) - { - static_assert(std::is_same::value, - "Fundamental types don't use custom decomposer."); - key = ::rocprim::detail::bit_cast(encode(key)); - } - - template - ROCPRIM_HOST_DEVICE static bool decode(bit_key_type bit_key, Decomposer decomposer = {}) - { - static_assert(std::is_same::value, - "Fundamental types don't use custom decomposer."); - const bool key_value = bit_key; - return Descending != key_value; - } - - template - ROCPRIM_HOST_DEVICE static void decode_inplace(bool& key, Decomposer decomposer = {}) - { - static_assert(std::is_same::value, - "Fundamental types don't use custom decomposer."); - key = decode(::rocprim::detail::bit_cast(key)); - } - - ROCPRIM_HOST_DEVICE static unsigned int - extract_digit(bit_key_type bit_key, unsigned int start, unsigned int radix_bits) - { - return base_type::template extract_digit(bit_key, start, radix_bits); - } - - template - ROCPRIM_HOST_DEVICE static unsigned int extract_digit(bool key, - unsigned int start, - unsigned int radix_bits, - Decomposer decomposer = {}) - { - static_assert(std::is_same::value, - "Fundamental types don't use custom decomposer."); - return extract_digit(::rocprim::detail::bit_cast(key), start, radix_bits); - } - - template - ROCPRIM_HOST_DEVICE static bool get_out_of_bounds_key(Decomposer decomposer = {}) - { - static_assert(std::is_same::value, - "Fundamental types don't use custom decomposer."); - return decode(static_cast(-1)); - } -}; -#endif // DOXYGEN_SHOULD_SKIP_THIS - -/// \brief Key encoder, decoder and bit-extractor for radix-based sorts with custom key types. -/// -/// \tparam Key Type of the key used. -/// \tparam Descending Whether the sort is increasing or decreasing.template -template -class radix_key_codec -{ -public: - /// \brief The key in this case is a custom type, so \p bit_key_type cannot be the type of the - /// encoded key because it depends on the decomposer used. It is thus set as the type \p Key. - using bit_key_type = Key; - - /// \brief Encodes a key of type \p Key into \p bit_key_type. - /// - /// \tparam Decomposer Decomposer functor type. Being \p Key a custom key type, the decomposer - /// type must be other than the \p identity_decomposer. - /// \param [in] key Key to encode. - /// \param [in] decomposer [optional] \p Key is a custom key type, so a custom decomposer - /// functor that returns a \p ::rocprim::tuple of references to fundamental types from a - /// \p Key key is needed. - /// \return A \p bit_key_type encoded key. - template - ROCPRIM_HOST_DEVICE static bit_key_type encode(Key key, Decomposer decomposer = {}) - { - encode_inplace(key, decomposer); - return static_cast(key); - } - - /// \brief Encodes in-place a key of type \p Key. - /// - /// \tparam Decomposer Decomposer functor type. Being \p Key a custom key type, the decomposer - /// type must be other than the \p identity_decomposer. - /// \param [in, out] key Key to encode. - /// \param [in] decomposer [optional] \p Key is a custom key type, so a custom decomposer - /// functor that returns a \p ::rocprim::tuple of references to fundamental types from a - /// \p Key key is needed. - template - ROCPRIM_HOST_DEVICE static void encode_inplace(Key& key, Decomposer decomposer = {}) - { - static_assert(!std::is_same::value, - "The decomposer of a custom-type key cannot be the identity decomposer."); - static_assert(::rocprim::detail::is_tuple_of_references::value, - "The decomposer must return a tuple of references."); - const auto per_element_encode = [](auto& tuple_element) - { - using element_type_t = std::decay_t; - using codec_t = radix_key_codec; - codec_t::encode_inplace(tuple_element); - }; - ::rocprim::detail::for_each_in_tuple(decomposer(key), per_element_encode); - } - - /// \brief Decodes an encoded key of type \p bit_key_type back into \p Key. - /// - /// \tparam Decomposer Decomposer functor type. Being \p Key a custom key type, the decomposer - /// type must be other than the \p identity_decomposer. - /// \param [in] bit_key Key to decode. - /// \param [in] decomposer [optional] \p Key is a custom key type, so a custom decomposer - /// functor that returns a \p ::rocprim::tuple of references to fundamental types from a - /// \p Key key is needed. - /// \return A \p Key decoded key. - template - ROCPRIM_HOST_DEVICE static Key decode(bit_key_type bit_key, Decomposer decomposer = {}) - { - decode_inplace(bit_key, decomposer); - return static_cast(bit_key); - } - - /// \brief Decodes in-place an encoded key of type \p Key. - /// - /// \tparam Decomposer Decomposer functor type. Being \p Key a custom key type, the decomposer - /// type must be other than the \p identity_decomposer. - /// \param [in, out] key Key to decode. - /// \param [in] decomposer [optional] Decomposer functor. - template - ROCPRIM_HOST_DEVICE static void decode_inplace(Key& key, Decomposer decomposer = {}) - { - static_assert(!std::is_same::value, - "The decomposer of a custom-type key cannot be the identity decomposer."); - static_assert(::rocprim::detail::is_tuple_of_references::value, - "The decomposer must return a tuple of references."); - const auto per_element_decode = [](auto& tuple_element) - { - using element_type_t = std::decay_t; - using codec_t = radix_key_codec; - codec_t::decode_inplace(tuple_element); - }; - ::rocprim::detail::for_each_in_tuple(decomposer(key), per_element_decode); - } - - /// \brief Extracts the specified bits from a given encoded key. - /// - /// \return Requested bits from the key. - ROCPRIM_HOST_DEVICE static unsigned int extract_digit(bit_key_type, unsigned int, unsigned int) - { - static_assert( - sizeof(bit_key_type) == 0, - "Only fundamental types (integral and floating point) are supported as radix sort" - "keys without a decomposer. " - "For custom key types, use the extract_digit overloads with the decomposer argument"); - } - - /// \brief Extracts the specified bits from a given in-place encoded key. - /// - /// \tparam Decomposer Decomposer functor type. Being \p Key a custom key type, the decomposer - /// type must be other than the \p identity_decomposer. - /// \param [in] key Key. - /// \param [in] start Start bit of the sequence of bits to extract. - /// \param [in] radix_bits How many bits to extract. - /// \param [in] decomposer \p Key is a custom key type, so a custom decomposer - /// functor that returns a \p ::rocprim::tuple of references to fundamental types from a - /// \p Key key is needed. - /// \return Requested bits from the key. - template - ROCPRIM_HOST_DEVICE static unsigned int - extract_digit(Key key, unsigned int start, unsigned int radix_bits, Decomposer decomposer) - { - static_assert(!std::is_same::value, - "The decomposer of a custom-type key cannot be the identity decomposer."); - static_assert(::rocprim::detail::is_tuple_of_references::value, - "The decomposer must return a tuple of references."); - constexpr size_t tuple_size - = ::rocprim::tuple_size>::value; - return extract_digit_from_key_impl(0u, - decomposer(key), - static_cast(start), - static_cast(start + radix_bits), - 0); - } - - /// \brief Gives the default value for out-of-bound keys of type \p Key. - /// - /// \tparam Decomposer Decomposer functor type. Being \p Key a custom key type, the decomposer - /// type must be other than the \p identity_decomposer. - /// \param [in] decomposer \p Key is a custom key type, so a custom decomposer - /// functor that returns a \p ::rocprim::tuple of references to fundamental types from a - /// \p Key key is needed. - /// \return Out-of-bound keys' value. - template - ROCPRIM_HOST_DEVICE static Key get_out_of_bounds_key(Decomposer decomposer) - { - static_assert(!std::is_same::value, - "The decomposer of a custom-type key cannot be the identity decomposer."); - static_assert(std::is_default_constructible::value, - "The sorted Key type must be default constructible"); - Key key; - ::rocprim::detail::for_each_in_tuple( - decomposer(key), - [](auto& element) - { - using element_t = std::decay_t; - using codec_t = radix_key_codec; - using bit_key_type = typename codec_t::bit_key_type; - element = codec_t::decode(static_cast(-1)); - }); - return key; - } - -private: - template - ROCPRIM_HOST_DEVICE static auto - extract_digit_from_key_impl(unsigned int digit, - const ::rocprim::tuple& key_tuple, - const int start, - const int end, - const int previous_bits) - -> std::enable_if_t<(ElementIndex >= 0), unsigned int> - { - using T = std::decay_t<::rocprim::tuple_element_t>>; - using bit_key_type = typename radix_key_codec::bit_key_type; - constexpr int current_element_bits = 8 * sizeof(T); - - const int total_extracted_bits = end - start; - const int current_element_end_bit = previous_bits + current_element_bits; - if(start < current_element_end_bit && end > previous_bits) - { - // unsigned integral representation of the current tuple element - const auto element_value = ::rocprim::detail::bit_cast( - ::rocprim::get(key_tuple)); - - const int bits_extracted_previously = ::rocprim::max(0, previous_bits - start); - - // start of the bit range copied from the current tuple element - const int current_start_bit = ::rocprim::max(0, start - previous_bits); - - // end of the bit range copied from the current tuple element - const int current_end_bit = ::rocprim::min(current_element_bits, - current_start_bit + total_extracted_bits - - bits_extracted_previously); - - // number of bits extracted from the current tuple element - const int current_length = current_end_bit - current_start_bit; - - // bits extracted from the current tuple element, aligned to LSB - unsigned int current_extract = (element_value >> current_start_bit); - - if(current_length != 32) - { - current_extract &= (1u << current_length) - 1; - } - - digit |= current_extract << bits_extracted_previously; - } - return extract_digit_from_key_impl(digit, - key_tuple, - start, - end, - previous_bits + current_element_bits); - } - - /// - template - ROCPRIM_HOST_DEVICE static auto - extract_digit_from_key_impl(unsigned int digit, - const ::rocprim::tuple& /*key_tuple*/, - const int /*start*/, - const int /*end*/, - const int /*previous_bits*/) - -> std::enable_if_t<(ElementIndex < 0), unsigned int> - { - return digit; - } -}; - -namespace detail -{ - -template -using radix_key_codec [[deprecated("radix_key_codec is now public API.")]] -= rocprim::radix_key_codec; - -} // namespace detail -END_ROCPRIM_NAMESPACE - -/// @} -// end of group threadmodule +ROCPRIM_PRAGMA_MESSAGE("Functionality from rocprim/detail/radix_key_codec.hpp has been moved to " + "rocprim/thread/type_traits.hpp.") +#include "../type_traits.hpp" #endif // ROCPRIM_THREAD_RADIX_KEY_CODEC_HPP_ diff --git a/rocprim/include/rocprim/thread/thread_load.hpp b/rocprim/include/rocprim/thread/thread_load.hpp index 2e96c4c82..0c5cfd065 100644 --- a/rocprim/include/rocprim/thread/thread_load.hpp +++ b/rocprim/include/rocprim/thread/thread_load.hpp @@ -58,8 +58,7 @@ enum cache_load_modifier : int load_cv = 4, ///< Cache as volatile (including cached system lines) load_ldg = 5, ///< Cache as texture load_volatile = 6, ///< Volatile (any memory space) - load_cs = load_nontemporal, ///< Alias for load_nontemporal (will be deprecated in 7.0) - load_count = 8 + load_count = 7 }; /// @} @@ -73,7 +72,7 @@ ROCPRIM_DEVICE ROCPRIM_INLINE T asm_thread_load(void* ptr) { T retval{}; - __builtin_memcpy(&retval, ptr, sizeof(T)); + __builtin_memcpy(static_cast(&retval), ptr, sizeof(T)); return retval; } @@ -155,10 +154,12 @@ std::enable_if_t thread_load(T* ptr) { - alignas(Alignment) T result; - detail::thread_fused_copy(&result, - ptr, - [](auto& dst, const auto& src) { dst = src; }); + using decay_type = typename std::remove_const_t; + alignas(Alignment) decay_type result; + detail::thread_fused_copy(&result, + ptr, + [](auto& dst, const auto& src) + { dst = src; }); return result; } @@ -188,14 +189,16 @@ ROCPRIM_DEVICE ROCPRIM_INLINE std::enable_if_t thread_load(T* ptr) { - alignas(Alignment) T result; - detail::thread_fused_copy(&result, - ptr, - [](auto& dst, const auto& src) - { - using U = std::remove_reference_t; - dst = *static_cast(&src); - }); + using decay_type = typename std::remove_const_t; + alignas(Alignment) decay_type result; + detail::thread_fused_copy( + &result, + ptr, + [](auto& dst, const auto& src) + { + using U = std::remove_reference_t; + dst = *static_cast(&src); + }); return result; } @@ -215,11 +218,12 @@ ROCPRIM_DEVICE ROCPRIM_INLINE std::enable_if_t thread_load(T* ptr) { #if __has_builtin(__builtin_nontemporal_load) - alignas(Alignment) T result; - detail::thread_fused_copy(&result, - ptr, - [](auto& dst, const auto& src) - { dst = __builtin_nontemporal_load(&src); }); + using decay_type = typename std::remove_const_t; + alignas(Alignment) decay_type result; + detail::thread_fused_copy( + &result, + ptr, + [](auto& dst, const auto& src) { dst = __builtin_nontemporal_load(&src); }); return result; #else return thread_load(ptr); diff --git a/rocprim/include/rocprim/thread/thread_reduce.hpp b/rocprim/include/rocprim/thread/thread_reduce.hpp index 97aa1d55a..4e52c970f 100644 --- a/rocprim/include/rocprim/thread/thread_reduce.hpp +++ b/rocprim/include/rocprim/thread/thread_reduce.hpp @@ -1,7 +1,7 @@ /****************************************************************************** * Copyright (c) 2010-2011, Duane Merrill. All rights reserved. * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. - * Modifications Copyright (c) 2021-2024, Advanced Micro Devices, Inc. All rights reserved. + * Modifications Copyright (c) 2021-2025, Advanced Micro Devices, Inc. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: @@ -31,10 +31,8 @@ #define ROCPRIM_THREAD_THREAD_REDUCE_HPP_ #include "../config.hpp" -#include "../functional.hpp" -#include "../type_traits.hpp" -#include +#include BEGIN_ROCPRIM_NAMESPACE @@ -45,46 +43,20 @@ BEGIN_ROCPRIM_NAMESPACE /// @{ /// \brief Carry out a reduction on an array of elements in one thread -/// \tparam LENGTH Length of the array to be reduced +/// \tparam Length Length of the array to be reduced /// \tparam T the input/output type /// \tparam ReductionOp Binary Operation that used to carry out the reduction -/// \tparam NoPrefix Boolean, determining whether to have a initialization value for the reduction accumulator /// \param input [in] Pointer to the first element of the array to be reduced /// \param reduction_op [in] Instance of the reduction operator functor -/// \param prefix [in] Value to be used as prefix, if NoPrefix is false +/// \param prefix [in] Optional value to be used as prefix /// \return Value obtained from reduction of input array -template +template ROCPRIM_DEVICE ROCPRIM_INLINE -auto thread_reduce(T* input, ReductionOp reduction_op, T prefix = T(0)) - -> std::enable_if_t::value, T> +auto thread_reduce(T* input, ReductionOp reduction_op, Prefix prefix = {}) { - T retval; - if(NoPrefix) - { - retval = input[0]; - } - else - { - retval = prefix; - } + T retval = input[0]; - ROCPRIM_UNROLL - for(int i = 0 + NoPrefix; i < LENGTH; ++i) - { - retval = reduction_op(retval, input[i]); - } - - return retval; -} - -/// \cond thread_reduce_specialization -template -ROCPRIM_DEVICE ROCPRIM_INLINE -auto thread_reduce(T* input, ReductionOp reduction_op, T prefix = T{}) - -> std::enable_if_t::value, T> -{ - T retval; - if(NoPrefix) + if constexpr(std::is_same_v) { retval = input[0]; } @@ -94,51 +66,27 @@ auto thread_reduce(T* input, ReductionOp reduction_op, T prefix = T{}) } ROCPRIM_UNROLL - for(int i = 0 + NoPrefix; i < LENGTH; ++i) + for(int i = 1; i < Length; ++i) { retval = reduction_op(retval, input[i]); } return retval; } -/// \endcond /// \brief Carry out a reduction on an array of elements in one thread -/// \tparam LENGTH Length of the array to be reduced +/// \tparam Length Length of the array to be reduced /// \tparam T the input/output type /// \tparam ReductionOp Binary Operation that used to carry out the reduction /// \param input [in] Pointer to the first element of the array to be reduced /// \param reduction_op [in] Instance of the reduction operator functor -/// \param prefix [in] Value to be used as prefix +/// \param prefix [in] Optional value to be used as prefix /// \return Value obtained from reduction of input array -template < - int LENGTH, - typename T, - typename ReductionOp> -ROCPRIM_DEVICE ROCPRIM_INLINE T thread_reduce( - T (&input)[LENGTH], - ReductionOp reduction_op, - T prefix) -{ - return thread_reduce((T*)input, reduction_op, prefix); -} - -/// \brief Carry out a reduction on an array of elements in one thread -/// \tparam LENGTH Length of the array to be reduced -/// \tparam T the input/output type -/// \tparam ReductionOp Binary Operation that used to carry out the reduction -/// \param input [in] Pointer to the first element of the array to be reduced -/// \param reduction_op [in] Instance of the reduction operator functor -/// \return Value obtained from reduction of input array -template < - int LENGTH, - typename T, - typename ReductionOp> -ROCPRIM_DEVICE ROCPRIM_INLINE T thread_reduce( - T (&input)[LENGTH], - ReductionOp reduction_op) +template +ROCPRIM_DEVICE ROCPRIM_INLINE +T thread_reduce(T (&input)[Length], ReductionOp reduction_op, Prefix prefix = {}) { - return thread_reduce((T*)input, reduction_op); + return thread_reduce(static_cast(input), reduction_op, prefix); } /// @} diff --git a/rocprim/include/rocprim/thread/thread_scan.hpp b/rocprim/include/rocprim/thread/thread_scan.hpp index 4ddd46f86..f9d35fd2d 100644 --- a/rocprim/include/rocprim/thread/thread_scan.hpp +++ b/rocprim/include/rocprim/thread/thread_scan.hpp @@ -1,7 +1,7 @@ /****************************************************************************** * Copyright (c) 2011, Duane Merrill. All rights reserved. * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. - * Modifications Copyright (c) 2021-2024, Advanced Micro Devices, Inc. All rights reserved. + * Modifications Copyright (c) 2021-2025, Advanced Micro Devices, Inc. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: @@ -195,6 +195,7 @@ ROCPRIM_DEVICE ROCPRIM_INLINE T if(apply_prefix) { inclusive = scan_op(prefix, inclusive); + asm volatile(""); // TEMP FIX. } output[0] = inclusive; // Continue scan diff --git a/rocprim/include/rocprim/thread/thread_store.hpp b/rocprim/include/rocprim/thread/thread_store.hpp index 8f14bf475..ec945b7e8 100644 --- a/rocprim/include/rocprim/thread/thread_store.hpp +++ b/rocprim/include/rocprim/thread/thread_store.hpp @@ -54,8 +54,7 @@ enum cache_store_modifier store_nontemporal = 3, ///< Cache streaming (likely not to be accessed again after storing) store_wt = 4, ///< Cache write-through (to system memory) store_volatile = 5, ///< Volatile (any memory space) - store_cs = store_nontemporal, ///< Alias for store_nontemporal (will be deprecated in 7.0) - store_count = 7 + store_count = 6 }; /// @} diff --git a/rocprim/include/rocprim/type_traits.hpp b/rocprim/include/rocprim/type_traits.hpp index 98c8d7b3b..aaaaa98a4 100644 --- a/rocprim/include/rocprim/type_traits.hpp +++ b/rocprim/include/rocprim/type_traits.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2025 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -21,545 +21,1392 @@ #ifndef ROCPRIM_TYPE_TRAITS_HPP_ #define ROCPRIM_TYPE_TRAITS_HPP_ -#include "config.hpp" -#include "functional.hpp" +#include "type_traits_functions.hpp" +#include "types.hpp" +#include -#include "type_traits_interface.hpp" +// common macros -#include "types/tuple.hpp" - -#include -#include - -/// \addtogroup utilsmodule_typetraits -/// @{ +/// \brief A reverse version of static_assert aims to increase code readability +#ifndef ROCPRIM_DO_NOT_COMPILE_IF + #define ROCPRIM_DO_NOT_COMPILE_IF(condition, msg) static_assert(!(condition), msg) +#endif +/// \brief Wrapper macro for std::enable_if aims to increase code readability +#ifndef ROCPRIM_REQUIRES + #define ROCPRIM_REQUIRES(...) typename std::enable_if<(__VA_ARGS__)>::type* = nullptr +#endif +#ifndef DOXYGEN_DOCUMENTATION_BUILD + /// \brief Since every definable traits need to use `is_defined`, this macro reduce the amount of code + #define ROCPRIM_TRAITS_GENERATE_IS_DEFINE(traits_name) \ + template \ + static constexpr bool is_defined = false; \ + template \ + static constexpr bool \ + is_defined::traits_name>> \ + = true +#endif BEGIN_ROCPRIM_NAMESPACE -/// \brief Extension of `std::make_unsigned`, which includes support for 128-bit integers. +namespace traits +{ +/// \par Overview +/// This template struct provides an interface for downstream libraries to implement type traits for +/// their custom types. Users can utilize this template struct to define traits for these types. Users +/// should only implement traits as required by specific algorithms, and some traits cannot be defined +/// if they can be inferred from others. This API is not static because of ODR. +/// \tparam T The type for which you want to define traits. +/// +/// \par Example +/// \parblock +/// The example below demonstrates how to implement traits for a custom floating-point type. +/// \code{.cpp} +/// // Your type definition +/// struct custom_float_type +/// {}; +/// // Implement the traits +/// template<> +/// struct rocprim::traits::define +/// { +/// using is_arithmetic = rocprim::traits::is_arithmetic::values; +/// using number_format = rocprim::traits::number_format::values; +/// using float_bit_mask = rocprim::traits::float_bit_mask::values; +/// }; +/// \endcode +/// The example below demonstrates how to implement traits for a custom integral type. +/// \code{.cpp} +/// // Your type definition +/// struct custom_int_type +/// {}; +/// // Implement the traits +/// template<> +/// struct rocprim::traits::define +/// { +/// using is_arithmetic = rocprim::traits::is_arithmetic::values; +/// using number_format = rocprim::traits::number_format::values; +/// using integral_sign = rocprim::traits::integral_sign::values; +/// }; +/// \endcode +/// \endparblock template -struct make_unsigned : std::make_unsigned +struct define {}; -#ifndef DOXYGEN_SHOULD_SKIP_THIS // skip specialized versions -template<> -struct make_unsigned<::rocprim::int128_t> +/// \par Definability +/// * **Undefinable**: For types with `predefined traits`. +/// * **Optional**: For other types. +/// \par How to define +/// \parblock +/// \code{.cpp} +/// using is_arithmetic = rocprim::traits::is_arithmetic::values; +/// \endcode +/// \endparblock +/// \par How to use +/// \parblock +/// \code{.cpp} +/// rocprim::traits::get().is_arithmetic(); +/// \endcode +/// \endparblock +struct is_arithmetic { - using type = ::rocprim::uint128_t; -}; - -template<> -struct make_unsigned<::rocprim::uint128_t> -{ - using type = ::rocprim::uint128_t; -}; -#endif + /// \brief Value of this trait + template + struct values + { + /// \brief This indicates if the `InputType` is arithmetic. + static constexpr auto value = Val; + }; -static_assert(std::is_same::type, ::rocprim::uint128_t>::value, - "'rocprim::int128_t' needs to implement 'make_unsigned' trait."); +#ifndef DOXYGEN_DOCUMENTATION_BUILD -/// \brief Extension of `std::numeric_limits`, which includes support for 128-bit integers. -template -struct numeric_limits : std::numeric_limits -{}; - -#ifndef DOXYGEN_SHOULD_SKIP_THIS // skip specialized versions -template<> -struct numeric_limits : std::numeric_limits -{ - static constexpr int digits = 128; - static constexpr int digits10 = 38; + ROCPRIM_TRAITS_GENERATE_IS_DEFINE(is_arithmetic); - static constexpr rocprim::uint128_t max() + // For c++ arithmetic types, return true, but will throw compile error when user try to define this trait for them + template::value)> + static constexpr auto get() { - return rocprim::int128_t{-1}; + ROCPRIM_DO_NOT_COMPILE_IF(is_defined, + "Do not define trait `is_arithmetic` for c++ arithmetic types"); + return values{}; } - static constexpr rocprim::uint128_t min() + // For third party types, if trait `is_arithmetic` not defined, will return default value `false` + template::value && !is_defined)> + static constexpr auto get() { - return rocprim::uint128_t{0}; + return values{}; } - static constexpr rocprim::uint128_t lowest() + // For third party types, if trait `is_arithmetic` is defined, then should return its value + template::value && is_defined)> + static constexpr auto get() { - return min(); + return typename define::is_arithmetic{}; } +#endif }; -template<> -struct numeric_limits : std::numeric_limits +/// \brief Arithmetic types, pointers, member pointers, and null pointers are considered scalar types. +/// \par Definability +/// * **Undefinable**: For types with `predefined traits`. +/// * **Optional**: For other types. If both `is_arithmetic` and `is_scalar` are defined, their values +/// must be consistent; otherwise, a compile-time error will occur. +/// \par How to define +/// \parblock +/// \code{.cpp} +/// using is_scalar = rocprim::traits::is_scalar::values; +/// \endcode +/// \endparblock +/// \par How to use +/// \parblock +/// \code{.cpp} +/// rocprim::traits::get().is_scalar(); +/// \endcode +/// \endparblock +struct is_scalar { - static constexpr int digits = 127; - static constexpr int digits10 = 38; + /// \brief Value of this trait + template + struct values + { + /// \brief This indicates if the `InputType` is scalar. + static constexpr auto value = Val; + }; + +#ifndef DOXYGEN_DOCUMENTATION_BUILD + + ROCPRIM_TRAITS_GENERATE_IS_DEFINE(is_scalar); - static constexpr rocprim::int128_t max() + // For c++ scalar types, return true, but will throw compile error when user try to define this trait for them + template::value)> + static constexpr auto get() { - return numeric_limits::max() >> 1; + ROCPRIM_DO_NOT_COMPILE_IF(is_defined, + "Do not define trait `is_scalar` for c++ scalar types"); + return values{}; } - static constexpr rocprim::int128_t min() + // For third party types, if trait `is_scalar` is not defined, will return default value `false` + // For rocprim or third party types that defined trait `is_arithmetic` as true the result should be `true` + template::value && !is_defined)> + static constexpr auto get() { - return -numeric_limits::max() - 1; + return values().value>{}; } - - static constexpr rocprim::int128_t lowest() + // For third party types and rocprim types, if trait `is_scalar` is defined, will return the value + // check if the `is_scalar` equals to `is_arithmetic`, or throw a compile error + template::value && is_defined)> + static constexpr auto get() { - return min(); + ROCPRIM_DO_NOT_COMPILE_IF( + is_arithmetic::get().value != typename define::is_scalar{}.value, + "Trait `is_arithmetic` and trait `is_scalar` should have the same value"); + return typename define::is_scalar{}; } +#endif }; -#endif // DOXYGEN_SHOULD_SKIP_THIS - -/// \brief Used to retrieve a type that can be treated as unsigned version of the template parameter. -/// \tparam T The signed type to find an unsigned equivalent for. -/// \tparam size the desired size (in bytes) of the unsigned type -template -struct get_unsigned_bits_type +/// \par Definability +/// * **Undefinable**: For types with `predefined traits` and non-arithmetic types. +/// * **Required**: If you define `is_arithmetic` as `true`, you must also define this trait; otherwise, a +/// compile-time error will occur. +/// \par How to define +/// \parblock +/// \code{.cpp} +/// using number_format = rocprim::traits::number_format::values; +/// \endcode +/// \endparblock +/// \par How to use +/// \parblock +/// \code{.cpp} +/// rocprim::traits::get().is_integral(); +/// rocprim::traits::get().is_floating_point(); +/// \endcode +/// \endparblock +struct number_format { - using unsigned_type = typename get_unsigned_bits_type:: - unsigned_type; ///< Typedefed to the unsigned type. -}; + /// \brief The kind enum that indecates the values avaliable for this trait + enum class kind + { + unknown_type = 0, + floating_point_type = 1, + integral_type = 2 + }; -#ifndef DOXYGEN_SHOULD_SKIP_THIS // skip specialized versions -template -struct get_unsigned_bits_type -{ - using unsigned_type = uint8_t; -}; + /// \brief Value of this trait + template + struct values + { + /// \brief This indicates if the `InputType` is floating_point_type or integral_type or unknown_type. + static constexpr auto value = Val; + }; -template -struct get_unsigned_bits_type -{ - using unsigned_type = uint16_t; -}; +#ifndef DOXYGEN_DOCUMENTATION_BUILD -template -struct get_unsigned_bits_type -{ - using unsigned_type = uint32_t; -}; + ROCPRIM_TRAITS_GENERATE_IS_DEFINE(number_format); -template -struct get_unsigned_bits_type -{ - using unsigned_type = uint64_t; -}; + // For c++ arithmetic types + template::value)> + static constexpr auto get() + { // C++ build-in arithmetic types are either floating point or integral + return values < std::is_floating_point::value ? kind::floating_point_type + : kind::integral_type > {}; + } -template -struct get_unsigned_bits_type -{ - using unsigned_type = ::rocprim::uint128_t; -}; -#endif // DOXYGEN_SHOULD_SKIP_THIS - -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -[[deprecated("TwiddleIn is deprecated." - "Use radix_key_codec instead.")]] ROCPRIM_DEVICE ROCPRIM_INLINE auto - TwiddleIn(UnsignedBits key) -> - typename std::enable_if::value, UnsignedBits>::type -{ - static const UnsignedBits HIGH_BIT = UnsignedBits(1) << ((sizeof(UnsignedBits) * 8) - 1); - UnsignedBits mask = (key & HIGH_BIT) ? UnsignedBits(-1) : HIGH_BIT; - return key ^ mask; -} - -template -[[deprecated("TwiddleIn is deprecated." - "Use radix_key_codec instead.")]] static ROCPRIM_DEVICE ROCPRIM_INLINE auto - TwiddleIn(UnsignedBits key) -> - typename std::enable_if::value, UnsignedBits>::type -{ - return key ; -}; + // For rocprim arithmetic types + template::value + && is_arithmetic::get().value)> + static constexpr auto get() + { + ROCPRIM_DO_NOT_COMPILE_IF(!is_defined, + "You must define trait `number_format` for arithmetic types"); + return typename define::number_format{}; + } -template -[[deprecated("TwiddleIn is deprecated." - "Use radix_key_codec instead.")]] static ROCPRIM_DEVICE ROCPRIM_INLINE auto - TwiddleIn(UnsignedBits key) -> - typename std::enable_if::value && is_signed::value, UnsignedBits>::type -{ - static const UnsignedBits HIGH_BIT = UnsignedBits(1) << ((sizeof(UnsignedBits) * 8) - 1); - return key ^ HIGH_BIT; + // For other types + template::value + && !is_arithmetic::get().value)> + static constexpr auto get() + { + ROCPRIM_DO_NOT_COMPILE_IF( + is_defined, + "You cannot define trait `number_format` for non-arithmetic types"); + return values{}; + } +#endif }; -template -[[deprecated("TwiddleOut is deprecated." - "Use radix_key_codec instead.")]] ROCPRIM_DEVICE ROCPRIM_INLINE auto - TwiddleOut(UnsignedBits key) -> - typename std::enable_if::value, UnsignedBits>::type +/// \par Definability +/// * **Undefinable**: For types with `predefined traits`, non-arithmetic types and floating-point types. +/// * **Required**: If you define `number_format` as `number_format::kind::floating_point_type`, you must also define this trait; otherwise, a +/// compile-time error will occur. +/// \par How to define +/// \parblock +/// \code{.cpp} +/// using integral_sign = rocprim::traits::integral_sign::values; +/// \endcode +/// \endparblock +/// \par How to use +/// \parblock +/// \code{.cpp} +/// rocprim::traits::get().is_signed(); +/// rocprim::traits::get().is_unsigned(); +/// \endcode +/// \endparblock +struct integral_sign { - static const UnsignedBits HIGH_BIT = UnsignedBits(1) << ((sizeof(UnsignedBits) * 8) - 1); - UnsignedBits mask = (key & HIGH_BIT) ? HIGH_BIT : UnsignedBits(-1); - return key ^ mask; -} - -template -[[deprecated("TwiddleOut is deprecated." - "Use radix_key_codec instead.")]] static ROCPRIM_DEVICE ROCPRIM_INLINE auto - TwiddleOut(UnsignedBits key) -> - typename std::enable_if::value, UnsignedBits>::type -{ - return key; -}; + /// \brief The kind enum that indecates the values avaliable for this trait + enum class kind + { + unknown_type = 0, + signed_type = 1, + unsigned_type = 2 + }; -template -[[deprecated("TwiddleOut is deprecated." - "Use radix_key_codec instead.")]] static ROCPRIM_DEVICE ROCPRIM_INLINE auto - TwiddleOut(UnsignedBits key) -> - typename std::enable_if::value && is_signed::value, UnsignedBits>::type -{ - static const UnsignedBits HIGH_BIT = UnsignedBits(1) << ((sizeof(UnsignedBits) * 8) - 1); - return key ^ HIGH_BIT; -}; -#endif // DOXYGEN_SHOULD_SKIP_THIS + /// \brief Value of this trait + template + struct values + { + /// \brief This indicates if the `InputType` is signed_type or unsigned_type or unknown_type. + static constexpr auto value = Val; + }; -namespace detail -{ +#ifndef DOXYGEN_DOCUMENTATION_BUILD -// invoke_result is based on std::invoke_result. -// The main difference is using ROCPRIM_HOST_DEVICE, this allows to -// use invoke_result with device-only lambdas/functors in host-only functions -// on HIP-clang. + ROCPRIM_TRAITS_GENERATE_IS_DEFINE(integral_sign); -template -struct is_reference_wrapper : std::false_type -{}; -template -struct is_reference_wrapper> : std::true_type -{}; + // For c++ arithmetic types + template::value)> + static constexpr auto get() + { // cpp arithmetic types are either signed point or unsignned + return values < std::is_signed::value ? kind::signed_type + : kind::unsigned_type > {}; + } -template -struct invoke_impl -{ - template - ROCPRIM_HOST_DEVICE static auto call(F&& f, Args&&... args) - -> decltype(std::forward(f)(std::forward(args)...)); -}; + // For rocprim arithmetic integral + template::value && is_arithmetic::get().value + && number_format::get().value == number_format::kind::integral_type)> + static constexpr auto get() + { + ROCPRIM_DO_NOT_COMPILE_IF(!is_defined, + "Trait `integral_sign` is required for arithmetic " + "integral types, please define"); + return typename define::integral_sign{}; + } -template -struct invoke_impl -{ - template::type, - class = typename std::enable_if::value>::type> - ROCPRIM_HOST_DEVICE static auto get(T&& t) -> T&&; - - template::type, - class = typename std::enable_if::value>::type> - ROCPRIM_HOST_DEVICE static auto get(T&& t) -> decltype(t.get()); - - template::type, - class = typename std::enable_if::value>::type, - class = typename std::enable_if::value>::type> - ROCPRIM_HOST_DEVICE static auto get(T&& t) -> decltype(*std::forward(t)); - - template::value>::type> - ROCPRIM_HOST_DEVICE static auto call(MT1 B::*pmf, T&& t, Args&&... args) - -> decltype((invoke_impl::get(std::forward(t)).*pmf)(std::forward(args)...)); + // For rocprim arithmetic non-integral + template::value && is_arithmetic::get().value + && number_format::get().value != number_format::kind::integral_type)> + static constexpr auto get() + { + ROCPRIM_DO_NOT_COMPILE_IF( + is_defined, + "You cannot define trait `integral_sign` for arithmetic non-integral types"); + return values{}; + } - template - ROCPRIM_HOST_DEVICE static auto call(MT B::*pmd, T&& t) - -> decltype(invoke_impl::get(std::forward(t)).*pmd); + // For other types + template::value + && !is_arithmetic::get().value)> + static constexpr auto get() + { // For other types, trait is_floating_point is a must + ROCPRIM_DO_NOT_COMPILE_IF( + is_defined, + "You cannot define trait `integral_sign` for non-arithmetic types"); + return values{}; + } +#endif }; -template::type> -ROCPRIM_HOST_DEVICE auto INVOKE(F&& f, Args&&... args) - -> decltype(invoke_impl::call(std::forward(f), std::forward(args)...)); - -// Conforming C++14 implementation (is also a valid C++11 implementation): -template -struct invoke_result_impl -{}; -template -struct invoke_result_impl(), std::declval()...))), - F, - Args...> +/// \par Definability +/// * **Undefinable**: For types with `predefined traits`, non-arithmetic types and integral types. +/// * **Required**: If you define `number_format` as `number_format::kind::unknown_type`, you must also define this trait; otherwise, a +/// compile-time error will occur. +/// \par How to define +/// \parblock +/// \code{.cpp} +/// using float_bit_mask = rocprim::traits::float_bit_mask::values; +/// \endcode +/// \endparblock +/// \par How to use +/// \parblock +/// \code{.cpp} +/// rocprim::traits::get().float_bit_mask(); +/// \endcode +/// \endparblock +struct float_bit_mask { - using type = decltype(INVOKE(std::declval(), std::declval()...)); -}; + /// \brief Value of this trait + template + struct values + { + ROCPRIM_DO_NOT_COMPILE_IF(number_format::get().value + != number_format::kind::integral_type, + "BitType should be integral"); + /// \brief Trait sign_bit for the `InputType`. + static constexpr BitType sign_bit = SignBit; + /// \brief Trait exponent for the `InputType`. + static constexpr BitType exponent = Exponent; + /// \brief Trait mantissa for the `InputType`. + static constexpr BitType mantissa = Mantissa; + }; -template -struct is_tuple -{ -public: - static constexpr bool value = false; -}; +#ifndef DOXYGEN_DOCUMENTATION_BUILD -template -struct is_tuple<::rocprim::tuple> -{ -private: - template - ROCPRIM_HOST_DEVICE - static constexpr bool is_tuple_impl() + ROCPRIM_TRAITS_GENERATE_IS_DEFINE(float_bit_mask); + + // If this trait is defined, then use the new interface + template)> + static constexpr auto get() { - return is_tuple_impl(); + ROCPRIM_DO_NOT_COMPILE_IF( + number_format::get().value != number_format::kind::floating_point_type, + "You cannot use trait `float_bit_mask` for `non-floating_point` types"); + return typename define::float_bit_mask{}; } - template<> - ROCPRIM_HOST_DEVICE - static constexpr bool is_tuple_impl() + // For types that don't have a trait `float_bit_mask` defined + template)> + static constexpr auto get() { - return true; + ROCPRIM_DO_NOT_COMPILE_IF( + number_format::get().value != number_format::kind::floating_point_type, + "You cannot use trait `float_bit_mask` for `non-floating_point` types"); + ROCPRIM_DO_NOT_COMPILE_IF(number_format::get().value + == number_format::kind::floating_point_type, + "Trait `float_bit_mask` is required for `floating_point` types"); + return values{}; } - -public: - static constexpr bool value = is_tuple_impl<0>(); -}; - -template -struct is_tuple_of_references -{ - static_assert(sizeof(T) == 0, "is_tuple_of_references is only implemented for rocprim::tuple"); +#endif }; -template -struct is_tuple_of_references<::rocprim::tuple> +/// \par Definability +/// * **Undefinable**: For all types. +/// \par Overview This triat is auto matically generated. +/// \par How to use +/// \parblock +/// \code{.cpp} +/// constexpr auto codec = rocprim::traits::get().radix_key_codec(); +/// using codec_t = decltype(codec); +/// \endcode +/// \endparblock +struct radix_key_codec { -private: - template - ROCPRIM_HOST_DEVICE static constexpr bool is_tuple_of_references_impl() +#ifndef DOXYGEN_DOCUMENTATION_BUILD + ROCPRIM_TRAITS_GENERATE_IS_DEFINE(radix_key_codec); + template + static ROCPRIM_HOST_DEVICE + auto bit_cast(const Source& source) + -> std::enable_if_t, Destination> { - using tuple_t = ::rocprim::tuple; - using element_t = ::rocprim::tuple_element_t; - return std::is_reference::value && is_tuple_of_references_impl(); + #if defined(__has_builtin) && __has_builtin(__builtin_bit_cast) + return __builtin_bit_cast(Destination, source); + #else + static_assert(std::is_trivially_constructable::value, + "Fallback implementation of bit_cast requires Destination to be trivially " + "constructible"); + Destination dest; + memcpy(&dest, &source, sizeof(Destination)); + return dest; + #endif } + template + using get_bit_key_type = typename std::conditional< + sizeof(Key) == sizeof(char), + unsigned char, + typename std::conditional< + sizeof(Key) == sizeof(short), + unsigned short, + typename std::conditional< + sizeof(Key) == sizeof(int), + unsigned int, + typename std::conditional< + sizeof(Key) == sizeof(long long), + unsigned long long, + typename std::conditional::type>::type>::type>::type>::type; + + /// \brief Encode and decode integral and floating point values for radix sort in such a way that preserves + /// correct order of negative and positive keys (i.e. negative keys go before positive ones, + /// which is not true for a simple reinterpetation of the key's bits). + /// + /// Digit extractor takes into account that (+0.0 == -0.0) is true for floats, + /// so both +0.0 and -0.0 are reflected into the same bit pattern for digit extraction. + /// Maximum digit length is 32. + template + struct codec_base + {}; + + /// \brief For unsigned integral types + template + struct codec_base< + Key, + typename std::enable_if< + number_format::get().value == number_format::kind::integral_type + && integral_sign::get().value == integral_sign::kind::unsigned_type>::type> + { + using bit_key_type = get_bit_key_type; + + ROCPRIM_HOST_DEVICE + static bit_key_type encode(Key key) + { + return bit_cast(key); + } + ROCPRIM_HOST_DEVICE + static Key decode(bit_key_type bit_key) + { + return bit_cast(bit_key); + } + + template + ROCPRIM_HOST_DEVICE + static unsigned int + extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length) + { + unsigned int mask = (1u << length) - 1; + return static_cast(bit_key >> start) & mask; + } + }; + + /// \brief For signed integral types + template + struct codec_base< + Key, + typename std::enable_if< + number_format::get().value == number_format::kind::integral_type + && integral_sign::get().value == integral_sign::kind::signed_type>::type> + { + using bit_key_type = get_bit_key_type; + + static constexpr bit_key_type sign_bit = bit_key_type(1) << (sizeof(bit_key_type) * 8 - 1); + + ROCPRIM_HOST_DEVICE + static bit_key_type encode(Key key) + { + const auto bit_key = bit_cast(key); + return sign_bit ^ bit_key; + } + + ROCPRIM_HOST_DEVICE + static Key decode(bit_key_type bit_key) + { + bit_key ^= sign_bit; + return bit_cast(bit_key); + } + + template + ROCPRIM_HOST_DEVICE + static unsigned int + extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length) + { + unsigned int mask = (1u << length) - 1; + return static_cast(bit_key >> start) & mask; + } + }; + + /// \brief For floating point types + template + struct codec_base().value + == number_format::kind::floating_point_type>::type> + { + using bit_key_type = get_bit_key_type; + + static constexpr bit_key_type sign_bit = float_bit_mask::get().sign_bit; + + ROCPRIM_HOST_DEVICE ROCPRIM_INLINE + static bit_key_type encode(Key key) + { + bit_key_type bit_key = bit_cast(key); + bit_key ^= (sign_bit & bit_key) == 0 ? sign_bit : bit_key_type(-1); + return bit_key; + } + + ROCPRIM_HOST_DEVICE ROCPRIM_INLINE + static Key decode(bit_key_type bit_key) + { + bit_key ^= (sign_bit & bit_key) == 0 ? bit_key_type(-1) : sign_bit; + return bit_cast(bit_key); + } + + template + ROCPRIM_HOST_DEVICE + static unsigned int + extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length) + { + unsigned int mask = (1u << length) - 1; + + // radix_key_codec_floating::encode() maps 0.0 to 0x8000'0000, + // and -0.0 to 0x7FFF'FFFF. + // radix_key_codec::encode() then flips the bits if descending, yielding: + // value | descending | ascending | + // ----- | ----------- | ----------- | + // 0.0 | 0x7FFF'FFFF | 0x8000'0000 | + // -0.0 | 0x8000'0000 | 0x7FFF'FFFF | + // + // For ascending sort, both should be mapped to 0x8000'0000, + // and for descending sort, both should be mapped to 0x7FFF'FFFF. + if constexpr(Descending) + { + bit_key = bit_key == sign_bit ? static_cast(~sign_bit) : bit_key; + } + else + { + bit_key = bit_key == static_cast(~sign_bit) ? sign_bit : bit_key; + } + return static_cast(bit_key >> start) & mask; + } + }; + /// \brief For bool template<> - ROCPRIM_HOST_DEVICE static constexpr bool is_tuple_of_references_impl() + struct codec_base { - return true; - } + using bit_key_type = unsigned char; + + ROCPRIM_HOST_DEVICE + static bit_key_type encode(bool key) + { + return static_cast(key); + } + + ROCPRIM_HOST_DEVICE + static bool decode(bit_key_type bit_key) + { + return static_cast(bit_key); + } + + template + ROCPRIM_HOST_DEVICE + static unsigned int + extract_digit(bit_key_type bit_key, unsigned int start, unsigned int length) + { + unsigned int mask = (1u << length) - 1; + return static_cast(bit_key >> start) & mask; + } + }; -public: - static constexpr bool value = is_tuple_of_references_impl<0>(); -}; + /// \brief Determines whether a type has bit_key_type + template + struct has_bit_key_type + { + template + static std::true_type check(typename U::bit_key_type*); -template -using value_type_t = typename std::iterator_traits::value_type; + template + static std::false_type check(...); -template -struct guarded_inequality_wrapper -{ - /// Wrapped equality operator - EqualityOp op; + using result = decltype(check(nullptr)); + }; + +#endif + /// \brief Determines whether the type is fundamental for `radix_key`. + template + using radix_key_fundamental = typename has_bit_key_type>::result; - /// Out-of-bounds limit - size_t guard; + /// \brief codec_base wrapper for fundamental radix key types + template::value> + class codec : protected codec_base + { + using base_type = codec_base; + + public: + /// \brief Type of the encoded key. + using bit_key_type = typename base_type::bit_key_type; + /// \brief Encodes a key of type \p Key into \p bit_key_type. + /// + /// \tparam Decomposer Being \p Key a fundamental type, \p Decomposer should be + /// \p identity_decomposer. This is also the type by default. + /// \param [in] key Key to encode. + /// \param [in] decomposer [optional] Decomposer functor. + /// \return A \p bit_key_type encoded key. + template + ROCPRIM_HOST_DEVICE + static bit_key_type encode(Key key, Decomposer decomposer = {}) + { + static_assert(std::is_same::value, + "Fundamental types don't use custom decomposer."); + bit_key_type bit_key = base_type::encode(key); + return Descending ? ~bit_key : bit_key; + } + + /// \brief Encodes in-place a key of type \p Key. + /// + /// \tparam Decomposer Being \p Key a fundamental type, \p Decomposer should be + /// \p identity_decomposer. This is also the type by default. + /// \param [in, out] key Key to encode. + /// \param [in] decomposer [optional] Decomposer functor. + template + ROCPRIM_HOST_DEVICE + static void encode_inplace(Key& key, Decomposer decomposer = {}) + { + static_assert(std::is_same::value, + "Fundamental types don't use custom decomposer."); + key = bit_cast(encode(key)); + } + + /// \brief Decodes an encoded key of type \p bit_key_type back into \p Key. + /// + /// \tparam Decomposer Being \p Key a fundamental type, \p Decomposer should be + /// \p identity_decomposer. This is also the type by default. + /// \param [in] bit_key Key to decode. + /// \param [in] decomposer [optional] Decomposer functor. + /// \return A \p Key decoded key. + template + ROCPRIM_HOST_DEVICE + static Key decode(bit_key_type bit_key, Decomposer decomposer = {}) + { + static_assert(std::is_same::value, + "Fundamental types don't use custom decomposer."); + bit_key = Descending ? ~bit_key : bit_key; + return base_type::decode(bit_key); + } + + /// \brief Decodes in-place an encoded key of type \p Key. + /// + /// \tparam Decomposer Being \p Key a fundamental type, \p Decomposer should be + /// \p identity_decomposer. This is also the type by default. + /// \param [in, out] key Key to decode. + /// \param [in] decomposer [optional] Decomposer functor. + template + ROCPRIM_HOST_DEVICE + static void decode_inplace(Key& key, Decomposer decomposer = {}) + { + static_assert(std::is_same::value, + "Fundamental types don't use custom decomposer."); + key = decode(bit_cast(key)); + } + + /// \brief Extracts the specified bits from a given encoded key. + /// + /// \param [in] bit_key Encoded key. + /// \param [in] start Start bit of the sequence of bits to extract. + /// \param [in] radix_bits How many bits to extract. + /// \return Requested bits from the key. + ROCPRIM_HOST_DEVICE + static unsigned int + extract_digit(bit_key_type bit_key, unsigned int start, unsigned int radix_bits) + { + return base_type::template extract_digit(bit_key, start, radix_bits); + } + + /// \brief Extracts the specified bits from a given in-place encoded key. + /// + /// \tparam Decomposer Being \p Key a fundamental type, \p Decomposer should be + /// \p identity_decomposer. This is also the type by default. + /// \param [in] key Key. + /// \param [in] start Start bit of the sequence of bits to extract. + /// \param [in] radix_bits How many bits to extract. + /// \param [in] decomposer [optional] Decomposer functor. + /// \return Requested bits from the key. + template + ROCPRIM_HOST_DEVICE + static unsigned int extract_digit(Key key, + unsigned int start, + unsigned int radix_bits, + Decomposer decomposer = {}) + { + static_assert(std::is_same::value, + "Fundamental types don't use custom decomposer."); + return extract_digit(bit_cast(key), start, radix_bits); + } + + /// \brief Gives the default value for out-of-bound keys of type \p Key. + /// + /// \tparam Decomposer Being \p Key a fundamental type, \p Decomposer should be + /// \p identity_decomposer. This is also the type by default. + /// \param [in] decomposer [optional] Decomposer functor. + /// \return Out-of-bound keys' value. + template + ROCPRIM_HOST_DEVICE + static Key get_out_of_bounds_key(Decomposer decomposer = {}) + { + static_assert(std::is_same::value, + "Fundamental types don't use custom decomposer."); + return decode(static_cast(-1)); + } + }; - /// Constructor - ROCPRIM_HOST_DEVICE inline guarded_inequality_wrapper(EqualityOp op, size_t guard) - : op(op), guard(guard) - {} +#ifndef DOXYGEN_DOCUMENTATION_BUILD + template + class codec : protected codec_base + { + using base_type = codec_base; + + public: + using bit_key_type = typename base_type::bit_key_type; + + template + ROCPRIM_HOST_DEVICE + static bit_key_type encode(bool key, Decomposer decomposer = {}) + { + static_assert(std::is_same::value, + "Fundamental types don't use custom decomposer."); + return Descending != key; + } + + template + ROCPRIM_HOST_DEVICE + static void encode_inplace(bool& key, Decomposer decomposer = {}) + { + static_assert(std::is_same::value, + "Fundamental types don't use custom decomposer."); + key = bit_cast(encode(key)); + } + + template + ROCPRIM_HOST_DEVICE + static bool decode(bit_key_type bit_key, Decomposer decomposer = {}) + { + static_assert(std::is_same::value, + "Fundamental types don't use custom decomposer."); + const bool key_value = bit_key; + return Descending != key_value; + } + + template + ROCPRIM_HOST_DEVICE + static void decode_inplace(bool& key, Decomposer decomposer = {}) + { + static_assert(std::is_same::value, + "Fundamental types don't use custom decomposer."); + key = decode(bit_cast(key)); + } + + ROCPRIM_HOST_DEVICE + static unsigned int + extract_digit(bit_key_type bit_key, unsigned int start, unsigned int radix_bits) + { + return base_type::template extract_digit(bit_key, start, radix_bits); + } + + template + ROCPRIM_HOST_DEVICE + static unsigned int extract_digit(bool key, + unsigned int start, + unsigned int radix_bits, + Decomposer decomposer = {}) + { + static_assert(std::is_same::value, + "Fundamental types don't use custom decomposer."); + return extract_digit(bit_cast(key), start, radix_bits); + } + + template + ROCPRIM_HOST_DEVICE + static bool get_out_of_bounds_key(Decomposer decomposer = {}) + { + static_assert(std::is_same::value, + "Fundamental types don't use custom decomposer."); + return decode(static_cast(-1)); + } + }; +#endif - /// \brief Guarded boolean inequality operator. - /// - /// \tparam T Type of the operands compared by the equality operator - /// \param a Left hand-side operand - /// \param b Right hand-side operand - /// \param idx Index of the thread calling to this operator. This is used to determine which - /// operations are out-of-bounds - /// \returns !op(a, b) for a certain equality operator \p op when in-bounds. - template + /// \brief Specialization of `class codec` for non-fundamental radix key types + template + class codec + { + public: + /// \brief The key in this case is a custom type, so \p bit_key_type cannot be the type of the + /// encoded key because it depends on the decomposer used. It is thus set as the type \p Key. + using bit_key_type = Key; + + /// \brief Encodes a key of type \p Key into \p bit_key_type. + /// + /// \tparam Decomposer Decomposer functor type. Being \p Key a custom key type, the decomposer + /// type must be other than the \p identity_decomposer. + /// \param [in] key Key to encode. + /// \param [in] decomposer [optional] \p Key is a custom key type, so a custom decomposer + /// functor that returns a \p ::rocprim::tuple of references to fundamental types from a + /// \p Key key is needed. + /// \return A \p bit_key_type encoded key. + template ROCPRIM_HOST_DEVICE - inline bool - operator()(const T& a, const T& b, size_t idx) const + static bit_key_type encode(Key key, Decomposer decomposer = {}) + { + encode_inplace(key, decomposer); + return static_cast(key); + } + + /// \brief Encodes in-place a key of type \p Key. + /// + /// \tparam Decomposer Decomposer functor type. Being \p Key a custom key type, the decomposer + /// type must be other than the \p identity_decomposer. + /// \param [in, out] key Key to encode. + /// \param [in] decomposer [optional] \p Key is a custom key type, so a custom decomposer + /// functor that returns a \p ::rocprim::tuple of references to fundamental types from a + /// \p Key key is needed. + template + ROCPRIM_HOST_DEVICE + static void encode_inplace(Key& key, Decomposer decomposer = {}) + { + static_assert(!std::is_same::value, + "The decomposer of a custom-type key cannot be the identity decomposer."); + static_assert( + ::rocprim::detail::is_tuple_of_references::value, + "The decomposer must return a tuple of references."); + const auto per_element_encode = [](auto& tuple_element) + { + using element_type_t = std::decay_t; + using codec_t = codec; + codec_t::encode_inplace(tuple_element); + }; + ::rocprim::detail::for_each_in_tuple(decomposer(key), per_element_encode); + } + + /// \brief Decodes an encoded key of type \p bit_key_type back into \p Key. + /// + /// \tparam Decomposer Decomposer functor type. Being \p Key a custom key type, the decomposer + /// type must be other than the \p identity_decomposer. + /// \param [in] bit_key Key to decode. + /// \param [in] decomposer [optional] \p Key is a custom key type, so a custom decomposer + /// functor that returns a \p ::rocprim::tuple of references to fundamental types from a + /// \p Key key is needed. + /// \return A \p Key decoded key. + template + ROCPRIM_HOST_DEVICE + static Key decode(bit_key_type bit_key, Decomposer decomposer = {}) + { + decode_inplace(bit_key, decomposer); + return static_cast(bit_key); + } + + /// \brief Decodes in-place an encoded key of type \p Key. + /// + /// \tparam Decomposer Decomposer functor type. Being \p Key a custom key type, the decomposer + /// type must be other than the \p identity_decomposer. + /// \param [in, out] key Key to decode. + /// \param [in] decomposer [optional] Decomposer functor. + template + ROCPRIM_HOST_DEVICE + static void decode_inplace(Key& key, Decomposer decomposer = {}) + { + static_assert(!std::is_same::value, + "The decomposer of a custom-type key cannot be the identity decomposer."); + static_assert( + ::rocprim::detail::is_tuple_of_references::value, + "The decomposer must return a tuple of references."); + const auto per_element_decode = [](auto& tuple_element) + { + using element_type_t = std::decay_t; + using codec_t = codec; + codec_t::decode_inplace(tuple_element); + }; + ::rocprim::detail::for_each_in_tuple(decomposer(key), per_element_decode); + } + + /// \brief Extracts the specified bits from a given encoded key. + /// + /// \return Requested bits from the key. + ROCPRIM_HOST_DEVICE + static unsigned int extract_digit(bit_key_type, unsigned int, unsigned int) + { + static_assert( + sizeof(bit_key_type) == 0, + "Only fundamental types (integral and floating point) are supported as radix sort" + "keys without a decomposer. " + "For custom key types, use the extract_digit overloads with the decomposer " + "argument"); + } + + /// \brief Extracts the specified bits from a given in-place encoded key. + /// + /// \tparam Decomposer Decomposer functor type. Being \p Key a custom key type, the decomposer + /// type must be other than the \p identity_decomposer. + /// \param [in] key Key. + /// \param [in] start Start bit of the sequence of bits to extract. + /// \param [in] radix_bits How many bits to extract. + /// \param [in] decomposer \p Key is a custom key type, so a custom decomposer + /// functor that returns a \p ::rocprim::tuple of references to fundamental types from a + /// \p Key key is needed. + /// \return Requested bits from the key. + template + ROCPRIM_HOST_DEVICE + static unsigned int extract_digit(Key key, + unsigned int start, + unsigned int radix_bits, + Decomposer decomposer) + { + static_assert(!std::is_same::value, + "The decomposer of a custom-type key cannot be the identity decomposer."); + static_assert( + ::rocprim::detail::is_tuple_of_references::value, + "The decomposer must return a tuple of references."); + constexpr size_t tuple_size + = ::rocprim::tuple_size>::value; + return extract_digit_from_key_impl(0u, + decomposer(key), + static_cast(start), + static_cast(start + radix_bits), + 0); + } + + /// \brief Gives the default value for out-of-bound keys of type \p Key. + /// + /// \tparam Decomposer Decomposer functor type. Being \p Key a custom key type, the decomposer + /// type must be other than the \p identity_decomposer. + /// \param [in] decomposer \p Key is a custom key type, so a custom decomposer + /// functor that returns a \p ::rocprim::tuple of references to fundamental types from a + /// \p Key key is needed. + /// \return Out-of-bound keys' value. + template + ROCPRIM_HOST_DEVICE + static Key get_out_of_bounds_key(Decomposer decomposer) + { + static_assert(!std::is_same::value, + "The decomposer of a custom-type key cannot be the identity decomposer."); + static_assert(std::is_default_constructible::value, + "The sorted Key type must be default constructible"); + Key key; + ::rocprim::detail::for_each_in_tuple( + decomposer(key), + [](auto& element) + { + using element_t = std::decay_t; + using codec_t = codec; + using bit_key_type = typename codec_t::bit_key_type; + element = codec_t::decode(static_cast(-1)); + }); + return key; + } + + private: + template + ROCPRIM_HOST_DEVICE + static auto extract_digit_from_key_impl(unsigned int digit, + const ::rocprim::tuple& key_tuple, + const int start, + const int end, + const int previous_bits) + -> std::enable_if_t<(ElementIndex >= 0), unsigned int> + { + using T + = std::decay_t<::rocprim::tuple_element_t>>; + using bit_key_type = typename codec::bit_key_type; + constexpr int current_element_bits = 8 * sizeof(T); + + const int total_extracted_bits = end - start; + const int current_element_end_bit = previous_bits + current_element_bits; + if(start < current_element_end_bit && end > previous_bits) + { + // unsigned integral representation of the current tuple element + const auto element_value + = bit_cast(::rocprim::get(key_tuple)); + + const int bits_extracted_previously = ::rocprim::max(0, previous_bits - start); + + // start of the bit range copied from the current tuple element + const int current_start_bit = ::rocprim::max(0, start - previous_bits); + + // end of the bit range copied from the current tuple element + const int current_end_bit = ::rocprim::min(current_element_bits, + current_start_bit + total_extracted_bits + - bits_extracted_previously); + + // number of bits extracted from the current tuple element + const int current_length = current_end_bit - current_start_bit; + + // bits extracted from the current tuple element, aligned to LSB + unsigned int current_extract = (element_value >> current_start_bit); + + if(current_length != 32) + { + current_extract &= (1u << current_length) - 1; + } + + digit |= current_extract << bits_extracted_previously; + } + return extract_digit_from_key_impl(digit, + key_tuple, + start, + end, + previous_bits + + current_element_bits); + } + + /// + template + ROCPRIM_HOST_DEVICE + static auto extract_digit_from_key_impl(unsigned int digit, + const ::rocprim::tuple& /*key_tuple*/, + const int /*start*/, + const int /*end*/, + const int /*previous_bits*/) + -> std::enable_if_t<(ElementIndex < 0), unsigned int> + { + return digit; + } + }; + + /// \brief The getter of this trait + /// \tparam Key type of the radix key + /// \returns The specialization of `rocprim::traits::radix_key_codec::codec`. + template + static constexpr auto get() { - // In-bounds return operation result, out-of-bounds return ret. - return (idx < guard) ? !op(a, b) : Ret; + return codec{}; } }; -} // end namespace detail - -/// \brief Behaves like ``std::invoke_result``, but allows the use of invoke_result -/// with device-only lambdas/functors in host-only functions on HIP-clang. -/// -/// \tparam F Type of the function. -/// \tparam Args Input type(s) to the function ``F``. -template -struct invoke_result : detail::invoke_result_impl +/// \par Overview +/// This template struct is designed to allow rocPRIM algorithms to retrieve trait information from C++ +/// build-in arithmetic types, rocPRIM types, and custom types. This API is not static because of ODR. +/// * All member functions are `compiled only when invoked`. +/// * Different algorithms require different traits. +/// \tparam T The type from which you want to retrieve the traits. +/// \par Example +/// \parblock +/// The following code demonstrates how to retrieve the traits of type `T`. +/// \code{.cpp} +/// // Get the trait in a template parameter +/// template().is_integral()>::type* = nullptr> +/// void get_traits_in_template_parameter(){} +/// // Get the trait in a function body +/// template +/// void get_traits_in_function_body(){ +/// constexpr auto input_traits = rocprim::traits::get(); +/// // Then you can use the member functinos +/// constexpr bool is_arithmetic = input_traits.is_arithmetic(); +/// } +/// \endcode +/// \endparblock +template +struct get { -#ifdef DOXYGEN_DOCUMENTATION_BUILD - /// \brief The return type of the Callable type F if invoked with the arguments Args. - /// \hideinitializer - using type = detail::invoke_result_impl::type; -#endif // DOXYGEN_DOCUMENTATION_BUILD -}; + /// \brief Get the value of trait `is_arithmetic`. + /// \returns `true` if `std::is_arithmetic_v` is `true`, or if type `T` is a rocPRIM arithmetic + /// type, or if the `is_arithmetic` trait has been defined as `true`; otherwise, returns `false`. + constexpr bool is_arithmetic() const + { + return rocprim::traits::is_arithmetic{}.get().value; + }; -/// \brief Helper type. It is an alias for ``invoke_result::type``. -/// -/// \tparam F Type of the function. -/// \tparam Args Input type(s) to the function ``F``. -template -using invoke_result_t = typename invoke_result::type; + /// \brief Get trait `is_fundamental`. + /// \returns `true` if `T` is a fundamental type (that is, rocPRIM arithmetic type, void, or nullptr_t); + /// otherwise, returns `false`. + constexpr bool is_fundamental() const + { + return std::is_fundamental::value || rocprim::traits::is_arithmetic{}.get().value; + }; -/// \brief Utility wrapper around ``invoke_result`` for binary operators. -/// -/// \tparam T Input type to the binary operator. -/// \tparam F Type of the binary operator. -template -struct invoke_result_binary_op -{ - /// \brief The result type of the binary operator. - using type = typename invoke_result::type; -}; + /// \brief Check if the type is a `build_in` type, this function is different from `is_fundamental`, + /// because, by implementing traits, downstream code can "hack" into rocprim to let a type be `arithmetic`, + /// and by following the rules of `std::is_fundamental`, `rocprim::is_fundamental` returns a union set of + /// `std::is_fundamental` and `rocprim::is_arithmetic`. So, to check wether a type is a build-in type, + /// please use this function. + /// \returns `true` if `T` is a `build_in` type (that is, char, unsigned char, short, unsigned short, int + /// unsigned int, long long, unsigned long long, rocprim::int128_t, rocprim::uint128_t, rocprim::half, + /// float, double); + constexpr bool is_build_in() const + { + return std::is_same::value || std::is_same::value + || std::is_same::value || std::is_same::value + || std::is_same::value || std::is_same::value + || std::is_same::value || std::is_same::value + || std::is_same::value + || std::is_same::value + || std::is_same::value + || std::is_same::value || std::is_same::value + || std::is_same::value || std::is_same::value; + } -/// \brief Helper type. It is an alias for ``invoke_result_binary_op::type``. -/// -/// \tparam T Input type to the binary operator. -/// \tparam F Type of the binary operator. -template -using invoke_result_binary_op_t = typename invoke_result_binary_op::type; + /// \brief If `T` is fundamental type, then returns `false`. + /// \returns `false` if `T` is a fundamental type (that is, rocPRIM arithmetic type, void, or nullptr_t); + /// otherwise, returns `true`. + constexpr bool is_compound() const + { + return !is_fundamental(); + } -namespace detail -{ + /// \brief To check if `T` is floating-point type. + /// \warning You cannot call this function when `is_arithmetic()` returns `false`; + /// doing so will result in a compile-time error. + constexpr bool is_floating_point() const + { + return rocprim::traits::number_format{}.get().value + == number_format::kind::floating_point_type; + }; -/// \brief If `T` is a rocPRIM binary functional type, provides the member constant `value` equal `true`. -/// For any other type, `value` is `false`. -template -struct is_binary_functional -{ - static constexpr bool value = false; -}; + /// \brief To check if `T` is integral type. + /// \warning You cannot call this function when `is_arithmetic()` returns `false`; + /// doing so will result in a compile-time error. + constexpr bool is_integral() const + { + return rocprim::traits::number_format{}.get().value + == number_format::kind::integral_type; + } -template -struct is_binary_functional> -{ - static constexpr bool value = true; -}; + /// \brief To check if `T` is signed integral type. + /// \warning You cannot call this function when `is_integral()` returns `false`; + /// doing so will result in a compile-time error. + constexpr bool is_signed() const + { + return rocprim::traits::integral_sign{}.get().value == integral_sign::kind::signed_type; + } -template -struct is_binary_functional> -{ - static constexpr bool value = true; -}; + /// \brief To check if `T` is unsigned integral type. + /// \warning You cannot call this function when `is_integral()` returns `false`; + /// doing so will result in a compile-time error. + constexpr bool is_unsigned() const + { + return rocprim::traits::integral_sign{}.get().value + == integral_sign::kind::unsigned_type; + } -template -struct is_binary_functional> -{ - static constexpr bool value = true; -}; + /// \brief Get trait `is_scalar`. + /// \returns `true` if `std::is_scalar_v` is `true`, or if type `T` is a rocPRIM arithmetic + /// type, or if the `is_scalar` trait has been defined as `true`; otherwise, returns `false`. + constexpr bool is_scalar() const + { + return rocprim::traits::is_scalar{}.get().value; + } -template -struct is_binary_functional> -{ - static constexpr bool value = true; + /// \brief Get trait `float_bit_mask`. + /// \warning You cannot call this function when `is_floating_point()` returns `false`; + /// doing so will result in a compile-time error. + /// \returns A constexpr instance of the specialization of `rocprim::traits::float_bit_mask::values` + /// as provided in the traits definition of type T. If the `float_bit_mask trait` is not defined, it + /// returns the rocprim::detail::float_bit_mask values, provided a specialization of + /// `rocprim::detail::float_bit_mask` exists. + constexpr auto float_bit_mask() const + { + return rocprim::traits::float_bit_mask{}.get(); + }; + + /// \brief Get trait `radix_key_codec`. + /// \returns A constexpr instance of the specialization of `rocprim::traits::radix_key_codec::codec` + /// as provided in the traits definition of type T. + template + constexpr auto radix_key_codec() const + { + return rocprim::traits::radix_key_codec{}.get(); + } }; -template -struct is_binary_functional> +} // namespace traits + +/// \defgroup rocprim_pre_defined_traits Trait definitions for rocPRIM arithmetic types and additional traits for +/// C++ build-in arithmetic types. +/// \addtogroup rocprim_pre_defined_traits +/// @{ + +/// \brief This is the definition of traits of `float` +/// C++ build-in type +template<> +struct traits::define { - static constexpr bool value = true; + /// \brief Trait `float_bit_mask` for this type + using float_bit_mask + = traits::float_bit_mask::values; }; -template -struct is_binary_functional> +/// \brief This is the definition of traits of `double` +/// C++ build-in type +template<> +struct traits::define { - static constexpr bool value = true; + /// \brief Trait `float_bit_mask` for this type + using float_bit_mask = traits::float_bit_mask:: + values; }; -template -struct is_binary_functional> +/// \brief This is the definition of traits of `rocprim::bfloat16` +/// rocPRIM arithmetic type +template<> +struct traits::define { - static constexpr bool value = true; + /// \brief Trait `is_arithmetic` for this type + using is_arithmetic = traits::is_arithmetic::values; + /// \brief Trait `number_format` for this type + using number_format + = traits::number_format::values; + /// \brief Trait `float_bit_mask` for this type + using float_bit_mask = traits::float_bit_mask::values; }; -template -struct is_binary_functional> +/// \brief This is the definition of traits of `rocprim::half` +/// rocPRIM arithmetic type +template<> +struct traits::define { - static constexpr bool value = true; + /// \brief Trait `is_arithmetic` for this type + using is_arithmetic = traits::is_arithmetic::values; + /// \brief Trait `number_format` for this type + using number_format + = traits::number_format::values; + /// \brief Trait `float_bit_mask` for this type + using float_bit_mask = traits::float_bit_mask::values; }; -template -struct is_binary_functional> +// Type traits like std::is_integral and std::is_arithmetic may be defined for 128-bit integral +// types (__int128_t and __uint128_t) in several cases: +// * with libstdc++ when GNU extensions are enabled (-std=gnu++17, which is the default C++ +// standard in clang); +// * always with libc++ (it is used on HIP SDK for Windows). + +namespace detail { - static constexpr bool value = true; -}; -template -struct is_binary_functional> +struct define_int128_t { - static constexpr bool value = true; + /// \brief Trait `is_arithmetic` for this type + using is_arithmetic = traits::is_arithmetic::values; + /// \brief Trait `number_format` for this type + using number_format = traits::number_format::values; + /// \brief Trait `integral_sign` for this type + using integral_sign = traits::integral_sign::values; }; -template -struct is_binary_functional> +struct define_uint128_t { - static constexpr bool value = true; + /// \brief Trait `is_arithmetic` for this type + using is_arithmetic = traits::is_arithmetic::values; + /// \brief Trait `number_format` for this type + using number_format = traits::number_format::values; + /// \brief Trait `integral_sign` for this type + using integral_sign = traits::integral_sign::values; }; } // namespace detail -/// \brief Helper struct it has the Type and the number of aligned bytes. -/// -/// \tparam T is the Type used to get the number of aligned bytes. -template -struct align_bytes -{ - /// Number of aligned bytes for type T - static constexpr unsigned value = alignof(T); - /// Type defined by T - using Type = T; -}; - -#ifndef DOXYGEN_SHOULD_SKIP_THIS // skip specialized versions -template -struct align_bytes : align_bytes +/// \brief This is the definition of traits of `rocprim::int128_t` +/// rocPRIM arithmetic type +template<> +struct traits::define + : std::conditional_t::value, + traits::define, + detail::define_int128_t> {}; -template -struct align_bytes : align_bytes + +/// \brief This is the definition of traits of `rocprim::uint128_t` +/// rocPRIM arithmetic type +template<> +struct traits::define + : std::conditional_t::value, + traits::define, + detail::define_uint128_t> {}; -template -struct align_bytes : align_bytes + +/// @} + +/// \brief An extension of `std::is_floating_point` that supports additional arithmetic types, +/// including `rocprim::half`, `rocprim::bfloat16`, and any types with trait +/// `rocprim::traits::number_format::values` implemented. +template +struct is_floating_point + : std::integral_constant().is_floating_point()> {}; -#endif -namespace detail -{ +/// \brief An extension of `std::is_integral` that supports additional arithmetic types, +/// including `rocprim::int128_t`, `rocprim::uint128_t`, and any types with trait +/// `rocprim::traits::number_format::values` implemented. +template +struct is_integral : std::integral_constant().is_integral()> +{}; -template -struct word_type -{ - static constexpr auto align_bytes_value = align_bytes::value; +/// \brief An extension of `std::is_arithmetic` that supports additional arithmetic types, +/// including any types with trait `rocprim::traits::is_arithmetic::values` implemented. +template +struct is_arithmetic : std::integral_constant().is_arithmetic()> +{}; - template - struct IsMultiple - { - static constexpr auto unit_align_bytes = align_bytes::value; - static constexpr bool is_multiple - = (sizeof(T) % sizeof(Unit) == 0) - && (int(align_bytes_value) % int(unit_align_bytes) == 0); - }; +/// \brief An extension of `std::is_fundamental` that supports additional arithmetic types, +/// including any types with trait `rocprim::traits::is_arithmetic::values` implemented. +template +struct is_fundamental : std::integral_constant().is_fundamental()> +{}; - using type = typename std::conditional::is_multiple, - unsigned int, - typename std::conditional::is_multiple, - unsigned short, - unsigned char>::type>::type; -}; +/// \brief An extension of `std::is_unsigned` that supports additional arithmetic types, +/// including `rocprim::uint128_t`, and any types with trait +/// `rocprim::traits::integral_sign::values` implemented. +template +struct is_unsigned : std::integral_constant().is_unsigned()> +{}; -template -struct word_type : word_type +/// \brief An extension of `std::is_signed` that supports additional arithmetic types, +/// including `rocprim::int128_t`, and any types with trait +/// `rocprim::traits::integral_sign::values` implemented. +template +struct is_signed : std::integral_constant().is_signed()> {}; -template -struct word_type : word_type + +/// \brief An extension of `std::is_scalar` that supports additional arithmetic types, +/// including any types with trait `rocprim::traits::is_scalar::values` implemented. +template +struct is_scalar : std::integral_constant().is_scalar()> {}; -template -struct word_type : word_type + +/// \brief An extension of `std::is_scalar` that supports additional non-arithmetic types. +template +struct is_compound : std::integral_constant().is_compound()> {}; -} // namespace detail +static_assert(::rocprim::traits::radix_key_codec::radix_key_fundamental::value, + "'int' should be fundamental"); +static_assert(!::rocprim::traits::radix_key_codec::radix_key_fundamental::value, + "'int*' should not be fundamental"); +static_assert(::rocprim::traits::radix_key_codec::radix_key_fundamental::value, + "'rocprim::int128_t' should be fundamental"); +static_assert(::rocprim::traits::radix_key_codec::radix_key_fundamental::value, + "'rocprim::uint128_t' should be fundamental"); +static_assert(!::rocprim::traits::radix_key_codec::radix_key_fundamental::value, + "'rocprim::int128_t*' should not be fundamental"); END_ROCPRIM_NAMESPACE -/// @} -// end of group utilsmodule_typetraits - -#endif // ROCPRIM_TYPE_TRAITS_HPP_ +#endif diff --git a/rocprim/include/rocprim/type_traits_functions.hpp b/rocprim/include/rocprim/type_traits_functions.hpp new file mode 100644 index 000000000..3edf8968e --- /dev/null +++ b/rocprim/include/rocprim/type_traits_functions.hpp @@ -0,0 +1,561 @@ +// Copyright (c) 2017-2025 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_TYPE_TRAITS_FUNCTIONS_HPP_ +#define ROCPRIM_TYPE_TRAITS_FUNCTIONS_HPP_ + +#include "config.hpp" +#include "functional.hpp" // not used +#include "types.hpp" +#include "types/integer_sequence.hpp" +#include "types/tuple.hpp" + +#include +#include +#include +#include +#include +#include + +/// \addtogroup utilsmodule_typetraits +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template +using void_t = void; + +} // namespace detail + +/// \brief Extension of `std::make_unsigned`, which includes support for 128-bit integers. +template +struct make_unsigned : std::make_unsigned +{}; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // skip specialized versions +template<> +struct make_unsigned<::rocprim::int128_t> +{ + using type = ::rocprim::uint128_t; +}; + +template<> +struct make_unsigned<::rocprim::uint128_t> +{ + using type = ::rocprim::uint128_t; +}; +#endif + +static_assert(std::is_same::type, ::rocprim::uint128_t>::value, + "'rocprim::int128_t' needs to implement 'make_unsigned' trait."); + +/// \brief Extension of `std::numeric_limits`, which includes support for 128-bit integers. +template +struct numeric_limits : std::numeric_limits +{}; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // skip specialized versions +template<> +struct numeric_limits : std::numeric_limits +{ + static constexpr int digits = 128; + static constexpr int digits10 = 38; + + static constexpr rocprim::uint128_t max() + { + return rocprim::int128_t{-1}; + } + + static constexpr rocprim::uint128_t min() + { + return rocprim::uint128_t{0}; + } + + static constexpr rocprim::uint128_t lowest() + { + return min(); + } +}; + +template<> +struct numeric_limits : std::numeric_limits +{ + static constexpr int digits = 127; + static constexpr int digits10 = 38; + + static constexpr rocprim::int128_t max() + { + return numeric_limits::max() >> 1; + } + + static constexpr rocprim::int128_t min() + { + return -numeric_limits::max() - 1; + } + + static constexpr rocprim::int128_t lowest() + { + return min(); + } +}; + +#endif // DOXYGEN_SHOULD_SKIP_THIS + +/// \brief Used to retrieve a type that can be treated as unsigned version of the template parameter. +/// \tparam T The signed type to find an unsigned equivalent for. +/// \tparam size the desired size (in bytes) of the unsigned type +template +struct get_unsigned_bits_type +{ + using unsigned_type = typename get_unsigned_bits_type:: + unsigned_type; ///< Typedefed to the unsigned type. +}; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // skip specialized versions +template +struct get_unsigned_bits_type +{ + using unsigned_type = uint8_t; +}; + +template +struct get_unsigned_bits_type +{ + using unsigned_type = uint16_t; +}; + +template +struct get_unsigned_bits_type +{ + using unsigned_type = uint32_t; +}; + +template +struct get_unsigned_bits_type +{ + using unsigned_type = uint64_t; +}; + +template +struct get_unsigned_bits_type +{ + using unsigned_type = ::rocprim::uint128_t; +}; +#endif // DOXYGEN_SHOULD_SKIP_THIS + +namespace detail +{ + +// invoke_result is based on std::invoke_result. +// The main difference is using ROCPRIM_HOST_DEVICE, this allows to +// use invoke_result with device-only lambdas/functors in host-only functions +// on HIP-clang. + +template +struct is_reference_wrapper : std::false_type +{}; +template +struct is_reference_wrapper> : std::true_type +{}; + +template +struct invoke_impl +{ + template + ROCPRIM_INLINE ROCPRIM_HOST_DEVICE + static auto call(F&& f, Args&&... args) + -> decltype(::std::forward(f)(::std::forward(args)...)); +}; + +template +struct invoke_impl +{ + template, ::std::enable_if_t<::std::is_base_of_v>> + ROCPRIM_INLINE ROCPRIM_HOST_DEVICE + constexpr auto get(T&& t) -> T&&; + + template, + ::std::enable_if_t::value>> + ROCPRIM_INLINE ROCPRIM_HOST_DEVICE + constexpr auto get(T&& t) -> decltype(t.get()); + + template, + ::std::enable_if_t>, + ::std::enable_if_t::value>> + ROCPRIM_INLINE ROCPRIM_HOST_DEVICE + constexpr auto get(T&& t) -> decltype(*::std::forward(t)); + + template>> + ROCPRIM_INLINE ROCPRIM_HOST_DEVICE + static auto call(MT1 B::*pmf, T&& t, Args&&... args) + -> decltype((invoke_impl::get(::std::forward(t)).*pmf)(::std::forward(args)...)); + + template + ROCPRIM_INLINE ROCPRIM_HOST_DEVICE + static auto call(MT B::*pmd, T&& t) -> decltype(invoke_impl::get(::std::forward(t)).*pmd); +}; + +template> +ROCPRIM_INLINE ROCPRIM_HOST_DEVICE +constexpr auto INVOKE(F&& f, Args&&... args) + -> decltype(invoke_impl::call(::std::forward(f), ::std::forward(args)...)); + +template +struct invoke_result_r_impl +{}; + +template +struct invoke_result_r_impl(), ::std::declval()...))), + void, + F, + Args...> +{ + using type = decltype(INVOKE(::std::declval(), ::std::declval()...)); +}; + +template +struct invoke_result_r_impl< + ::std::enable_if_t< + ::std::is_convertible_v(), ::std::declval()...)), + R>, + void>, + R, + F, + Args...> +{ + + using type = decltype(INVOKE(::std::declval(), ::std::declval()...)); +}; + +template +struct is_tuple +{ + static constexpr bool value = false; +}; + +template +struct is_tuple<::rocprim::tuple> +{ + static constexpr bool value = true; +}; + +template +struct is_tuple_of_references +{ + static_assert(sizeof(T) == 0, "is_tuple_of_references is only implemented for rocprim::tuple"); +}; + +template +struct is_tuple_of_references<::rocprim::tuple> +{ +private: + template + ROCPRIM_HOST_DEVICE + static constexpr + typename std::enable_if<(Index < sizeof...(Args)), bool>::type is_tuple_of_references_impl() + { + using tuple_t = ::rocprim::tuple; + using element_t = ::rocprim::tuple_element_t; + return std::is_reference::value && is_tuple_of_references_impl(); + } + + template + ROCPRIM_HOST_DEVICE + static constexpr typename std::enable_if<(Index == sizeof...(Args)), bool>::type + is_tuple_of_references_impl() + { + return true; + } + +public: + static constexpr bool value = is_tuple_of_references::is_tuple_of_references_impl<0>(); +}; + +template +ROCPRIM_HOST_DEVICE +inline void for_each_in_tuple_impl(Tuple&& t, Function&& f, ::rocprim::index_sequence) +{ + int swallow[] + = {(std::forward(f)(::rocprim::get(std::forward(t))), 0)...}; + (void)swallow; +} + +template +ROCPRIM_HOST_DEVICE +inline auto for_each_in_tuple(Tuple&& t, Function&& f) + -> void_t>> +{ + static constexpr size_t size = tuple_size>::value; + for_each_in_tuple_impl(std::forward(t), + std::forward(f), + ::rocprim::make_index_sequence()); +} + +template +using value_type_t = typename std::iterator_traits::value_type; + +template +struct guarded_inequality_wrapper +{ + /// Wrapped equality operator + EqualityOp op; + + /// Out-of-bounds limit + size_t guard; + + /// Constructor + ROCPRIM_HOST_DEVICE inline guarded_inequality_wrapper(EqualityOp op, size_t guard) + : op(op), guard(guard) + {} + + /// \brief Guarded boolean inequality operator. + /// + /// \tparam T Type of the operands compared by the equality operator + /// \param a Left hand-side operand + /// \param b Right hand-side operand + /// \param idx Index of the thread calling to this operator. This is used to determine which + /// operations are out-of-bounds + /// \returns !op(a, b) for a certain equality operator \p op when in-bounds. + template + ROCPRIM_HOST_DEVICE + inline bool + operator()(const T& a, const T& b, size_t idx) const + { + // In-bounds return operation result, out-of-bounds return ret. + return (idx < guard) ? !op(a, b) : Ret; + } +}; + +} // end namespace detail + +/// \brief Similar to ``rocprim::invoke_result``, but also checks if the result +/// can be converted to a specified return type when the return type is not ``void``. +/// +/// \tparam R The type to which the function's return type must be convertible. +/// \tparam F Type of the function. +/// \tparam Args Input type(s) to the function ``F``. +template +struct invoke_result_r : detail::invoke_result_r_impl +{ +#ifdef DOXYGEN_DOCUMENTATION_BUILD + /// \brief The return type of the Callable type F if invoked with the arguments Args. + /// \hideinitializer + using type = detail::invoke_result_r_impl::type; +#endif // DOXYGEN_DOCUMENTATION_BUILD +}; + +/// \brief Behaves like ``std::invoke_result``, but allows the use of invoke_result +/// with device-only lambdas/functors in host-only functions on HIP-clang. +/// +/// \tparam F Type of the function. +/// \tparam Args Input type(s) to the function ``F``. +template +using invoke_result = invoke_result_r; + +/// \brief Helper type. It is an alias for ``invoke_result::type``. +/// +/// \tparam F Type of the function. +/// \tparam Args Input type(s) to the function ``F``. +template +using invoke_result_t = typename invoke_result::type; + +/// \brief Utility wrapper around ``invoke_result`` for binary operators. +/// +/// \tparam T Input type to the binary operator. +/// \tparam F Type of the binary operator. +template +struct [[deprecated("To deduce the type of accumulator, use 'rocprim::accumulator_t' " + "instead.")]] invoke_result_binary_op +{ + /// \brief The result type of the binary operator. + using type = typename invoke_result::type; +}; + +/// \brief Helper type. It is an alias for ``invoke_result_binary_op::type``. +/// +/// \tparam T Input type to the binary operator. +/// \tparam F Type of the binary operator. +template +using invoke_result_binary_op_t + [[deprecated("To deduce the type of accumulator, use 'rocprim::accumulator_t' instead.")]] + = typename invoke_result_binary_op::type; + +/// \brief The type of intermediate accumulator (according to CCCL) +template +using accumulator_t = ::std::decay_t>; + +namespace detail +{ + +/// \brief If `T` is a rocPRIM binary functional type, provides the member constant `value` equal `true`. +/// For any other type, `value` is `false`. +template +struct is_binary_functional +{ + static constexpr bool value = false; +}; + +template +struct is_binary_functional> +{ + static constexpr bool value = true; +}; + +template +struct is_binary_functional> +{ + static constexpr bool value = true; +}; + +template +struct is_binary_functional> +{ + static constexpr bool value = true; +}; + +template +struct is_binary_functional> +{ + static constexpr bool value = true; +}; + +template +struct is_binary_functional> +{ + static constexpr bool value = true; +}; + +template +struct is_binary_functional> +{ + static constexpr bool value = true; +}; + +template +struct is_binary_functional> +{ + static constexpr bool value = true; +}; + +template +struct is_binary_functional> +{ + static constexpr bool value = true; +}; + +template +struct is_binary_functional> +{ + static constexpr bool value = true; +}; + +template +struct is_binary_functional> +{ + static constexpr bool value = true; +}; + +template +struct is_binary_functional> +{ + static constexpr bool value = true; +}; + +} // namespace detail + +/// \brief Helper struct it has the Type and the number of aligned bytes. +/// +/// \tparam T is the Type used to get the number of aligned bytes. +template +struct align_bytes +{ + /// Number of aligned bytes for type T + static constexpr unsigned value = alignof(T); + /// Type defined by T + using Type = T; +}; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // skip specialized versions +template +struct align_bytes : align_bytes +{}; +template +struct align_bytes : align_bytes +{}; +template +struct align_bytes : align_bytes +{}; +#endif + +namespace detail +{ + +template +struct word_type +{ + static constexpr auto align_bytes_value = align_bytes::value; + + template + struct IsMultiple + { + static constexpr auto unit_align_bytes = align_bytes::value; + static constexpr bool is_multiple + = (sizeof(T) % sizeof(Unit) == 0) + && (int(align_bytes_value) % int(unit_align_bytes) == 0); + }; + + using type = typename std::conditional::is_multiple, + unsigned int, + typename std::conditional::is_multiple, + unsigned short, + unsigned char>::type>::type; +}; + +template +struct word_type : word_type +{}; +template +struct word_type : word_type +{}; +template +struct word_type : word_type +{}; + +} // namespace detail + +namespace detail +{ +template +constexpr bool is_valid_bit_cast + = sizeof(Destination) == sizeof(Source) && std::is_trivially_copyable::value + && std::is_trivially_copyable::value; +} // namespace detail + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group utilsmodule_typetraits + +#endif // ROCPRIM_TYPE_TRAITS_HPP_ diff --git a/rocprim/include/rocprim/type_traits_interface.hpp b/rocprim/include/rocprim/type_traits_interface.hpp deleted file mode 100644 index df5c2cfd3..000000000 --- a/rocprim/include/rocprim/type_traits_interface.hpp +++ /dev/null @@ -1,797 +0,0 @@ -// Copyright (c) 2024-2025 Advanced Micro Devices, Inc. All rights reserved. -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. - -#ifndef ROCPRIM_TYPE_TRAITS_INTERFACE_HPP_ -#define ROCPRIM_TYPE_TRAITS_INTERFACE_HPP_ - -#include "types.hpp" - -#include - -// common macros - -/// \brief A reverse version of static_assert aims to increase code readability -#ifndef ROCPRIM_DO_NOT_COMPILE_IF - #define ROCPRIM_DO_NOT_COMPILE_IF(condition, msg) static_assert(!(condition), msg) -#endif -/// \brief Wrapper macro for std::enable_if aims to increase code readability -#ifndef ROCPRIM_REQUIRES - #define ROCPRIM_REQUIRES(...) typename std::enable_if<(__VA_ARGS__)>::type* = nullptr -#endif -#ifndef DOXYGEN_DOCUMENTATION_BUILD - /// \brief Since every definable traits need to use `is_defined`, this macro reduce the amount of code - #define ROCPRIM_TRAITS_GENERATE_IS_DEFINE(traits_name) \ - template \ - static constexpr bool is_defined = false; \ - template \ - static constexpr bool \ - is_defined::traits_name>> \ - = true -#endif - -BEGIN_ROCPRIM_NAMESPACE - -namespace detail -{ - -template -using void_t = void; - -template -struct [[deprecated]] float_bit_mask; - -} // namespace detail - -namespace traits -{ -/// \defgroup type_traits_interfaces Interfaces for defining and obtaining the traits -/// \addtogroup type_traits_interfaces -/// @{ - -/// \par Overview -/// This template struct provides an interface for downstream libraries to implement type traits for -/// their custom types. Users can utilize this template struct to define traits for these types. Users -/// should only implement traits as required by specific algorithms, and some traits cannot be defined -/// if they can be inferred from others. This API is not static because of ODR. -/// \tparam T The type for which you want to define traits. -/// -/// \par Example -/// \parblock -/// The example below demonstrates how to implement traits for a custom floating-point type. -/// \code{.cpp} -/// // Your type definition -/// struct custom_float_type -/// {}; -/// // Implement the traits -/// template<> -/// struct rocprim::traits::define -/// { -/// using is_arithmetic = rocprim::traits::is_arithmetic::values; -/// using number_format = rocprim::traits::number_format::values; -/// using float_bit_mask = rocprim::traits::float_bit_mask::values; -/// }; -/// \endcode -/// The example below demonstrates how to implement traits for a custom integral type. -/// \code{.cpp} -/// // Your type definition -/// struct custom_int_type -/// {}; -/// // Implement the traits -/// template<> -/// struct rocprim::traits::define -/// { -/// using is_arithmetic = rocprim::traits::is_arithmetic::values; -/// using number_format = rocprim::traits::number_format::values; -/// using integral_sign = rocprim::traits::integral_sign::values; -/// }; -/// \endcode -/// \endparblock -template -struct define -{}; - -/// @} - -/// predef -template -struct get; - -/// \defgroup available_traits Traits that can be used -/// \addtogroup available_traits -/// @{ - -/// \par Definability -/// * **Undefinable**: For types with `predefined traits`. -/// * **Optional**: For other types. -/// \par How to define -/// \parblock -/// \code{.cpp} -/// using is_arithmetic = rocprim::traits::is_arithmetic::values; -/// \endcode -/// \endparblock -/// \par How to use -/// \parblock -/// \code{.cpp} -/// rocprim::traits::get().is_arithmetic(); -/// \endcode -/// \endparblock -struct is_arithmetic -{ - /// \brief Value of this trait - template - struct values - { - /// \brief This indicates if the `InputType` is arithmetic. - static constexpr auto value = Val; - }; - -#ifndef DOXYGEN_DOCUMENTATION_BUILD - - ROCPRIM_TRAITS_GENERATE_IS_DEFINE(is_arithmetic); - - // For c++ arithmetic types, return true, but will throw compile error when user try to define this trait for them - template::value)> - static constexpr auto get() - { - ROCPRIM_DO_NOT_COMPILE_IF(is_defined, - "Do not define trait `is_arithmetic` for c++ arithmetic types"); - return values{}; - } - - // For third party types, if trait `is_arithmetic` not defined, will return default value `false` - template::value && !is_defined)> - static constexpr auto get() - { - return values{}; - } - - // For third party types, if trait `is_arithmetic` is defined, then should return its value - template::value && is_defined)> - static constexpr auto get() - { - return typename define::is_arithmetic{}; - } -#endif -}; - -/// \brief Arithmetic types, pointers, member pointers, and null pointers are considered scalar types. -/// \par Definability -/// * **Undefinable**: For types with `predefined traits`. -/// * **Optional**: For other types. If both `is_arithmetic` and `is_scalar` are defined, their values -/// must be consistent; otherwise, a compile-time error will occur. -/// \par How to define -/// \parblock -/// \code{.cpp} -/// using is_scalar = rocprim::traits::is_scalar::values; -/// \endcode -/// \endparblock -/// \par How to use -/// \parblock -/// \code{.cpp} -/// rocprim::traits::get().is_scalar(); -/// \endcode -/// \endparblock -struct is_scalar -{ - /// \brief Value of this trait - template - struct values - { - /// \brief This indicates if the `InputType` is scalar. - static constexpr auto value = Val; - }; - -#ifndef DOXYGEN_DOCUMENTATION_BUILD - - ROCPRIM_TRAITS_GENERATE_IS_DEFINE(is_scalar); - - // For c++ scalar types, return true, but will throw compile error when user try to define this trait for them - template::value)> - static constexpr auto get() - { - ROCPRIM_DO_NOT_COMPILE_IF(is_defined, - "Do not define trait `is_scalar` for c++ scalar types"); - return values{}; - } - - // For third party types, if trait `is_scalar` is not defined, will return default value `false` - // For rocprim or third party types that defined trait `is_arithmetic` as true the result should be `true` - template::value && !is_defined)> - static constexpr auto get() - { - return values().value>{}; - } - // For third party types and rocprim types, if trait `is_scalar` is defined, will return the value - // check if the `is_scalar` equals to `is_arithmetic`, or throw a compile error - template::value && is_defined)> - static constexpr auto get() - { - ROCPRIM_DO_NOT_COMPILE_IF( - is_arithmetic::get().value != typename define::is_scalar{}.value, - "Trait `is_arithmetic` and trait `is_scalar` should have the same value"); - return typename define::is_scalar{}; - } -#endif -}; - -/// \par Definability -/// * **Undefinable**: For types with `predefined traits` and non-arithmetic types. -/// * **Required**: If you define `is_arithmetic` as `true`, you must also define this trait; otherwise, a -/// compile-time error will occur. -/// \par How to define -/// \parblock -/// \code{.cpp} -/// using number_format = rocprim::traits::number_format::values; -/// \endcode -/// \endparblock -/// \par How to use -/// \parblock -/// \code{.cpp} -/// rocprim::traits::get().is_integral(); -/// rocprim::traits::get().is_floating_point(); -/// \endcode -/// \endparblock -struct number_format -{ - /// \brief The kind enum that indecates the values avaliable for this trait - enum class kind - { - unknown_type = 0, - floating_point_type = 1, - integral_type = 2 - }; - - /// \brief Value of this trait - template - struct values - { - /// \brief This indicates if the `InputType` is floating_point_type or integral_type or unknown_type. - static constexpr auto value = Val; - }; - -#ifndef DOXYGEN_DOCUMENTATION_BUILD - - ROCPRIM_TRAITS_GENERATE_IS_DEFINE(number_format); - - // For c++ arithmetic types - template::value)> - static constexpr auto get() - { // C++ build-in arithmetic types are either floating point or integral - return values < std::is_floating_point::value ? kind::floating_point_type - : kind::integral_type > {}; - } - - // For rocprim arithmetic types - template::value - && is_arithmetic::get().value)> - static constexpr auto get() - { - ROCPRIM_DO_NOT_COMPILE_IF(!is_defined, - "You must define trait `number_format` for arithmetic types"); - return typename define::number_format{}; - } - - // For other types - template::value - && !is_arithmetic::get().value)> - static constexpr auto get() - { - ROCPRIM_DO_NOT_COMPILE_IF( - is_defined, - "You cannot define trait `number_format` for non-arithmetic types"); - return values{}; - } -#endif -}; - -/// \par Definability -/// * **Undefinable**: For types with `predefined traits`, non-arithmetic types and floating-point types. -/// * **Required**: If you define `number_format` as `number_format::kind::floating_point_type`, you must also define this trait; otherwise, a -/// compile-time error will occur. -/// \par How to define -/// \parblock -/// \code{.cpp} -/// using integral_sign = rocprim::traits::integral_sign::values; -/// \endcode -/// \endparblock -/// \par How to use -/// \parblock -/// \code{.cpp} -/// rocprim::traits::get().is_signed(); -/// rocprim::traits::get().is_unsigned(); -/// \endcode -/// \endparblock -struct integral_sign -{ - /// \brief The kind enum that indecates the values avaliable for this trait - enum class kind - { - unknown_type = 0, - signed_type = 1, - unsigned_type = 2 - }; - - /// \brief Value of this trait - template - struct values - { - /// \brief This indicates if the `InputType` is signed_type or unsigned_type or unknown_type. - static constexpr auto value = Val; - }; - -#ifndef DOXYGEN_DOCUMENTATION_BUILD - - ROCPRIM_TRAITS_GENERATE_IS_DEFINE(integral_sign); - - // For c++ arithmetic types - template::value)> - static constexpr auto get() - { // cpp arithmetic types are either signed point or unsignned - return values < std::is_signed::value ? kind::signed_type - : kind::unsigned_type > {}; - } - - // For rocprim arithmetic integral - template::value && is_arithmetic::get().value - && number_format::get().value == number_format::kind::integral_type)> - static constexpr auto get() - { - ROCPRIM_DO_NOT_COMPILE_IF(!is_defined, - "Trait `integral_sign` is required for arithmetic " - "integral types, please define"); - return typename define::integral_sign{}; - } - - // For rocprim arithmetic non-integral - template::value && is_arithmetic::get().value - && number_format::get().value != number_format::kind::integral_type)> - static constexpr auto get() - { - ROCPRIM_DO_NOT_COMPILE_IF( - is_defined, - "You cannot define trait `integral_sign` for arithmetic non-integral types"); - return values{}; - } - - // For other types - template::value - && !is_arithmetic::get().value)> - static constexpr auto get() - { // For other types, trait is_floating_point is a must - ROCPRIM_DO_NOT_COMPILE_IF( - is_defined, - "You cannot define trait `integral_sign` for non-arithmetic types"); - return values{}; - } -#endif -}; - -/// \warning For some types, if this trait is not implemented in their traits definition, it will -/// link to `rocprim::detail::float_bit_mask` to maintain compatibility with downstream libraries. -/// However, this linkage will be removed in the next major release. Please ensure that these types -/// are updated to the latest interface. -/// \par Definability -/// * **Undefinable**: For types with `predefined traits`, non-arithmetic types and integral types. -/// * **Required**: If you define `number_format` as `number_format::kind::unknown_type`, you must also define this trait; otherwise, a -/// compile-time error will occur. -/// \par How to define -/// \parblock -/// \code{.cpp} -/// using float_bit_mask = rocprim::traits::float_bit_mask::values; -/// \endcode -/// \endparblock -/// \par How to use -/// \parblock -/// \code{.cpp} -/// rocprim::traits::get().float_bit_mask(); -/// \endcode -/// \endparblock -struct float_bit_mask -{ - /// \brief Value of this trait - template - struct values - { - ROCPRIM_DO_NOT_COMPILE_IF(number_format::get().value - != number_format::kind::integral_type, - "BitType should be integral"); - /// \brief Trait sign_bit for the `InputType`. - static constexpr BitType sign_bit = SignBit; - /// \brief Trait exponent for the `InputType`. - static constexpr BitType exponent = Exponent; - /// \brief Trait mantissa for the `InputType`. - static constexpr BitType mantissa = Mantissa; - }; - -#ifndef DOXYGEN_DOCUMENTATION_BUILD - - ROCPRIM_TRAITS_GENERATE_IS_DEFINE(float_bit_mask); - - template - static constexpr bool has_old_float_bit_mask = false; - template - static constexpr bool has_old_float_bit_mask< - InputType, - detail::void_t{})>> - = true; - - // If this trait is defined, then use the new interface - template)> - static constexpr auto get() - { - ROCPRIM_DO_NOT_COMPILE_IF( - number_format::get().value != number_format::kind::floating_point_type, - "You cannot use trait `float_bit_mask` for `non-floating_point` types"); - return typename define::float_bit_mask{}; - } - - // This function acts as a bridge for old interface. Will be removed in certain version - // "`rocprim::detail::float_bit_mask` will be deprecated on next main release," - // "`please use rocprim::trait::define` to define tratis for types." - template && has_old_float_bit_mask)> - static constexpr auto get() - { - using mask = typename ::rocprim::detail::float_bit_mask; - return values{}; - } - - // For types that don't have a trait `float_bit_mask` defined neither a rocprim::detail::float_bit_mask specialization - template && !has_old_float_bit_mask)> - static constexpr auto get() - { - ROCPRIM_DO_NOT_COMPILE_IF( - number_format::get().value != number_format::kind::floating_point_type, - "You cannot use trait `float_bit_mask` for `non-floating_point` types"); - ROCPRIM_DO_NOT_COMPILE_IF(number_format::get().value - == number_format::kind::floating_point_type, - "Trait `float_bit_mask` is required for `floating_point` types"); - return values{}; - } -#endif -}; - -/// \brief The trait `is_fundamental` is undefinable, as it is the union of `std::is_fundamental` -/// and `rocprim::traits::is_arithmetic`. -/// \par Definability -/// * **Undefinable**: If you attempt to define this trait in any form, a compile-time error will occur. -/// \par How to use -/// \parblock -/// \code{.cpp} -/// rocprim::traits::get().is_fundamental(); -/// rocprim::traits::get().is_compound(); -/// \endcode -/// \endparblock -struct is_fundamental -{ - - /// \brief Value of this trait - template - struct values - { - /// \brief This indicates if the `InputType` is fundamental. - static constexpr auto value = Val; - }; - -#ifndef DOXYGEN_DOCUMENTATION_BUILD - - ROCPRIM_TRAITS_GENERATE_IS_DEFINE(is_fundamental); - - // For all types - template - static constexpr auto get() - { - ROCPRIM_DO_NOT_COMPILE_IF(is_defined, "Trait `is_fundamental` is undefinable"); - return values < std::is_fundamental::value - || is_arithmetic::get().value > {}; - } -#endif -}; - -/// @} - -/// \addtogroup type_traits_interfaces -/// @{ - -/// \par Overview -/// This template struct is designed to allow rocPRIM algorithms to retrieve trait information from C++ -/// build-in arithmetic types, rocPRIM types, and custom types. This API is not static because of ODR. -/// * All member functions are `compiled only when invoked`. -/// * Different algorithms require different traits. -/// \tparam T The type from which you want to retrieve the traits. -/// \par Example -/// \parblock -/// The following code demonstrates how to retrieve the traits of type `T`. -/// \code{.cpp} -/// // Get the trait in a template parameter -/// template().is_integral()>::type* = nullptr> -/// void get_traits_in_template_parameter(){} -/// // Get the trait in a function body -/// template -/// void get_traits_in_function_body(){ -/// constexpr auto input_traits = rocprim::traits::get(); -/// // Then you can use the member functinos -/// constexpr bool is_arithmetic = input_traits.is_arithmetic(); -/// } -/// \endcode -/// \endparblock -template -struct get -{ - /// \brief Get the value of trait `is_arithmetic`. - /// \returns `true` if `std::is_arithmetic_v` is `true`, or if type `T` is a rocPRIM arithmetic - /// type, or if the `is_arithmetic` trait has been defined as `true`; otherwise, returns `false`. - constexpr bool is_arithmetic() const - { - return rocprim::traits::is_arithmetic{}.get().value; - }; - - /// \brief Get trait `is_fundamental`. - /// \returns `true` if `T` is a fundamental type (that is, rocPRIM arithmetic type, void, or nullptr_t); - /// otherwise, returns `false`. - constexpr bool is_fundamental() const - { - return rocprim::traits::is_fundamental{}.get().value; - }; - - /// \brief If `T` is fundamental type, then returns `false`. - /// \returns `false` if `T` is a fundamental type (that is, rocPRIM arithmetic type, void, or nullptr_t); - /// otherwise, returns `true`. - constexpr bool is_compound() const - { - return !rocprim::traits::is_fundamental{}.get().value; - } - - /// \brief To check if `T` is floating-point type. - /// \warning You cannot call this function when `is_arithmetic()` returns `false`; - /// doing so will result in a compile-time error. - constexpr bool is_floating_point() const - { - return rocprim::traits::number_format{}.get().value - == number_format::kind::floating_point_type; - }; - - /// \brief To check if `T` is integral type. - /// \warning You cannot call this function when `is_arithmetic()` returns `false`; - /// doing so will result in a compile-time error. - constexpr bool is_integral() const - { - return rocprim::traits::number_format{}.get().value - == number_format::kind::integral_type; - } - - /// \brief To check if `T` is signed integral type. - /// \warning You cannot call this function when `is_integral()` returns `false`; - /// doing so will result in a compile-time error. - constexpr bool is_signed() const - { - return rocprim::traits::integral_sign{}.get().value == integral_sign::kind::signed_type; - } - - /// \brief To check if `T` is unsigned integral type. - /// \warning You cannot call this function when `is_integral()` returns `false`; - /// doing so will result in a compile-time error. - constexpr bool is_unsigned() const - { - return rocprim::traits::integral_sign{}.get().value - == integral_sign::kind::unsigned_type; - } - - /// \brief Get trait `is_scalar`. - /// \returns `true` if `std::is_scalar_v` is `true`, or if type `T` is a rocPRIM arithmetic - /// type, or if the `is_scalar` trait has been defined as `true`; otherwise, returns `false`. - constexpr bool is_scalar() const - { - return rocprim::traits::is_scalar{}.get().value; - } - - /// \brief Get trait `float_bit_mask`. - /// \warning You cannot call this function when `is_floating_point()` returns `false`; - /// doing so will result in a compile-time error. - /// \returns A constexpr instance of the specialization of `rocprim::traits::float_bit_mask::values` - /// as provided in the traits definition of type T. If the `float_bit_mask trait` is not defined, it - /// returns the rocprim::detail::float_bit_mask values, provided a specialization of - /// `rocprim::detail::float_bit_mask` exists. - constexpr auto float_bit_mask() const - { - return rocprim::traits::float_bit_mask{}.get(); - }; -}; - -/// @} - -} // namespace traits - -/// \defgroup rocprim_pre_defined_traits Trait definitions for rocPRIM arithmetic types and additional traits for -/// C++ build-in arithmetic types. -/// \addtogroup rocprim_pre_defined_traits -/// @{ - -/// \brief This is the definition of traits of `float` -/// C++ build-in type -template<> -struct traits::define -{ - /// \brief Trait `float_bit_mask` for this type - using float_bit_mask - = traits::float_bit_mask::values; -}; - -/// \brief This is the definition of traits of `double` -/// C++ build-in type -template<> -struct traits::define -{ - /// \brief Trait `float_bit_mask` for this type - using float_bit_mask = traits::float_bit_mask:: - values; -}; - -/// \brief This is the definition of traits of `rocprim::bfloat16` -/// rocPRIM arithmetic type -template<> -struct traits::define -{ - /// \brief Trait `is_arithmetic` for this type - using is_arithmetic = traits::is_arithmetic::values; - /// \brief Trait `number_format` for this type - using number_format - = traits::number_format::values; - /// \brief Trait `float_bit_mask` for this type - using float_bit_mask = traits::float_bit_mask::values; -}; - -/// \brief This is the definition of traits of `rocprim::half` -/// rocPRIM arithmetic type -template<> -struct traits::define -{ - /// \brief Trait `is_arithmetic` for this type - using is_arithmetic = traits::is_arithmetic::values; - /// \brief Trait `number_format` for this type - using number_format - = traits::number_format::values; - /// \brief Trait `float_bit_mask` for this type - using float_bit_mask = traits::float_bit_mask::values; -}; - -// Type traits like std::is_integral and std::is_arithmetic may be defined for 128-bit integral -// types (__int128_t and __uint128_t) in several cases: -// * with libstdc++ when GNU extensions are enabled (-std=gnu++17, which is the default C++ -// standard in clang); -// * always with libc++ (it is used on HIP SDK for Windows). - -namespace detail -{ - -struct define_int128_t -{ - /// \brief Trait `is_arithmetic` for this type - using is_arithmetic = traits::is_arithmetic::values; - /// \brief Trait `number_format` for this type - using number_format = traits::number_format::values; - /// \brief Trait `integral_sign` for this type - using integral_sign = traits::integral_sign::values; -}; - -struct define_uint128_t -{ - /// \brief Trait `is_arithmetic` for this type - using is_arithmetic = traits::is_arithmetic::values; - /// \brief Trait `number_format` for this type - using number_format = traits::number_format::values; - /// \brief Trait `integral_sign` for this type - using integral_sign = traits::integral_sign::values; -}; - -} // namespace detail - -/// \brief This is the definition of traits of `rocprim::int128_t` -/// rocPRIM arithmetic type -template<> -struct traits::define - : std::conditional_t::value, - traits::define, - detail::define_int128_t> -{}; - -/// \brief This is the definition of traits of `rocprim::uint128_t` -/// rocPRIM arithmetic type -template<> -struct traits::define - : std::conditional_t::value, - traits::define, - detail::define_uint128_t> -{}; - -/// @} - -/// \defgroup rocprim_type_traits_wrapper Handy wrappers for obtaining type traits -/// \addtogroup rocprim_type_traits_wrapper -/// @{ - -/// \brief An extension of `std::is_floating_point` that supports additional arithmetic types, -/// including `rocprim::half`, `rocprim::bfloat16`, and any types with trait -/// `rocprim::traits::number_format::values` implemented. -template -struct is_floating_point - : std::integral_constant().is_floating_point()> -{}; - -/// \brief An extension of `std::is_integral` that supports additional arithmetic types, -/// including `rocprim::int128_t`, `rocprim::uint128_t`, and any types with trait -/// `rocprim::traits::number_format::values` implemented. -template -struct is_integral : std::integral_constant().is_integral()> -{}; - -/// \brief An extension of `std::is_arithmetic` that supports additional arithmetic types, -/// including any types with trait `rocprim::traits::is_arithmetic::values` implemented. -template -struct is_arithmetic : std::integral_constant().is_arithmetic()> -{}; - -/// \brief An extension of `std::is_fundamental` that supports additional arithmetic types, -/// including any types with trait `rocprim::traits::is_arithmetic::values` implemented. -template -struct is_fundamental : std::integral_constant().is_fundamental()> -{}; - -/// \brief An extension of `std::is_unsigned` that supports additional arithmetic types, -/// including `rocprim::uint128_t`, and any types with trait -/// `rocprim::traits::integral_sign::values` implemented. -template -struct is_unsigned : std::integral_constant().is_unsigned()> -{}; - -/// \brief An extension of `std::is_signed` that supports additional arithmetic types, -/// including `rocprim::int128_t`, and any types with trait -/// `rocprim::traits::integral_sign::values` implemented. -template -struct is_signed : std::integral_constant().is_signed()> -{}; - -/// \brief An extension of `std::is_scalar` that supports additional arithmetic types, -/// including any types with trait `rocprim::traits::is_scalar::values` implemented. -template -struct is_scalar : std::integral_constant().is_scalar()> -{}; - -/// \brief An extension of `std::is_scalar` that supports additional non-arithmetic types. -template -struct is_compound : std::integral_constant().is_compound()> -{}; - -/// @} -END_ROCPRIM_NAMESPACE - -#endif diff --git a/rocprim/include/rocprim/types.hpp b/rocprim/include/rocprim/types.hpp index fa1e2f40e..4f3dab9ee 100644 --- a/rocprim/include/rocprim/types.hpp +++ b/rocprim/include/rocprim/types.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -35,6 +35,8 @@ #include "types/tuple.hpp" #include "types/uninitialized_array.hpp" +#include "intrinsics/arch.hpp" + /// \addtogroup utilsmodule /// @{ @@ -54,7 +56,8 @@ struct make_vector_type /// \brief Empty type used as a placeholder, usually used to flag that given /// template parameter should not be used. -struct empty_type {}; +struct empty_type +{}; /// \brief A decomposer that must be passed to the radix sort algorithms when /// sorting keys that are arithmetic types. @@ -74,14 +77,19 @@ using bfloat16 = ::hip_bfloat16; /// \brief The lane_mask_type is an integer that contains one bit per thread. /// -/// The total number of bits is equal to the total number of threads in a -/// warp. Used to for warp-level operations. -/// \note This is defined only on the device side, see `ROCPRIM_WAVEFRONT_SIZE` for details. -#if ROCPRIM_WAVEFRONT_SIZE == 32 -using lane_mask_type = unsigned int; -#elif ROCPRIM_WAVEFRONT_SIZE == 64 -using lane_mask_type = unsigned long long int; -#endif +/// When targeting AMDGCN, the total number of bits is equal to the total +/// number of threads in a warp. Used for warp-level operations. When +/// targeting SPIR-V, it assumes 64 threads per warp. +/// +/// \note When called on the host, assumes 64-bit wide masks. +/// +/// \note When targeting SPIR-V, this type will be 64-bit wide. Extra +/// precaution must be taken as the number of bits in this type is not +/// always the same as the number of lanes. +using lane_mask_type = std::conditional_t<::rocprim::arch::wavefront::get_target() + == ::rocprim::arch::wavefront::target::size32, + unsigned int, + unsigned long long int>; /// \brief Native half-precision floating point type using native_half = _Float16; diff --git a/rocprim/include/rocprim/warp/detail/warp_reduce_crosslane.hpp b/rocprim/include/rocprim/warp/detail/warp_reduce_crosslane.hpp index e9cba59bf..9a9706c83 100644 --- a/rocprim/include/rocprim/warp/detail/warp_reduce_crosslane.hpp +++ b/rocprim/include/rocprim/warp/detail/warp_reduce_crosslane.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2018-2020 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -33,18 +33,14 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template< - class T, - unsigned int WarpSize, - bool UseAllReduce, - bool UseDPP = ROCPRIM_DETAIL_USE_DPP -> +template using warp_reduce_crosslane = - typename std::conditional< - UseDPP, - warp_reduce_dpp, - warp_reduce_shuffle - >::type; + typename std::conditional, + warp_reduce_shuffle>::type; } // end namespace detail diff --git a/rocprim/include/rocprim/warp/detail/warp_reduce_dpp.hpp b/rocprim/include/rocprim/warp/detail/warp_reduce_dpp.hpp index c2b147ad1..5633e4506 100644 --- a/rocprim/include/rocprim/warp/detail/warp_reduce_dpp.hpp +++ b/rocprim/include/rocprim/warp/detail/warp_reduce_dpp.hpp @@ -24,9 +24,9 @@ #include #include "../../config.hpp" +#include "../../detail/various.hpp" #include "../../intrinsics.hpp" #include "../../types.hpp" -#include "../../detail/various.hpp" #include "warp_reduce_shuffle.hpp" @@ -35,15 +35,11 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template< - class T, - unsigned int WarpSize, - bool UseAllReduce -> +template class warp_reduce_dpp { public: - static_assert(detail::is_power_of_two(WarpSize), "WarpSize must be power of 2"); + static_assert(detail::is_power_of_two(VirtualWaveSize), "VirtualWaveSize must be power of 2"); using storage_type = detail::empty_storage_type; @@ -53,24 +49,24 @@ class warp_reduce_dpp { output = input; - if(WarpSize > 1) + if(VirtualWaveSize > 1) { // quad_perm:[1,0,3,2] -> 10110001 output = reduce_op(warp_move_dpp(output), output); } - if(WarpSize > 2) + if(VirtualWaveSize > 2) { // quad_perm:[2,3,0,1] -> 01001110 output = reduce_op(warp_move_dpp(output), output); } - if(WarpSize > 4) + if(VirtualWaveSize > 4) { // row_ror:4 // Use rotation instead of shift to avoid leaving invalid values in the destination // registers (asume warp size of at least hardware warp-size) output = reduce_op(warp_move_dpp(output), output); } - if(WarpSize > 8) + if(VirtualWaveSize > 8) { // row_ror:8 // Use rotation instead of shift to avoid leaving invalid values in the destination @@ -78,34 +74,35 @@ class warp_reduce_dpp output = reduce_op(warp_move_dpp(output), output); } #ifdef ROCPRIM_DETAIL_HAS_DPP_BROADCAST - if(WarpSize > 16) + if(VirtualWaveSize > 16) { // row_bcast:15 output = reduce_op(warp_move_dpp(output), output); } - if(WarpSize > 32) + if(VirtualWaveSize > 32) { // row_bcast:31 output = reduce_op(warp_move_dpp(output), output); } - static_assert(WarpSize <= 64, "WarpSize > 64 is not supported"); + static_assert(VirtualWaveSize <= 64, "VirtualWaveSize > 64 is not supported"); #else - if(WarpSize > 16) + if(VirtualWaveSize > 16) { // row_bcast:15 output = reduce_op(warp_swizzle(output), output); } - static_assert(WarpSize <= 32, "WarpSize > 32 is not supported without DPP broadcasts"); + static_assert(VirtualWaveSize <= 32, + "VirtualWaveSize > 32 is not supported without DPP broadcasts"); #endif // Read the result from the last lane of the logical warp - output = warp_shuffle(output, WarpSize - 1, WarpSize); + output = warp_shuffle(output, VirtualWaveSize - 1, VirtualWaveSize); } template ROCPRIM_DEVICE ROCPRIM_INLINE void reduce_impl(T input, T& output, BinaryFunction reduce_op, std::true_type) { - warp_reduce_shuffle().reduce(input, output, reduce_op); + warp_reduce_shuffle().reduce(input, output, reduce_op); } template @@ -116,14 +113,15 @@ class warp_reduce_dpp input, output, reduce_op, - std::integral_constant{}); + std::integral_constant{}); } template ROCPRIM_DEVICE ROCPRIM_INLINE void reduce(T input, T& output, storage_type& storage, BinaryFunction reduce_op) { - (void) storage; // disables unused parameter warning + (void)storage; // disables unused parameter warning this->reduce(input, output, reduce_op); } @@ -132,16 +130,21 @@ class warp_reduce_dpp void reduce(T input, T& output, unsigned int valid_items, BinaryFunction reduce_op) { // Fallback to shuffle-based implementation - warp_reduce_shuffle() - .reduce(input, output, valid_items, reduce_op); + warp_reduce_shuffle().reduce(input, + output, + valid_items, + reduce_op); } template ROCPRIM_DEVICE ROCPRIM_INLINE - void reduce(T input, T& output, unsigned int valid_items, - storage_type& storage, BinaryFunction reduce_op) + void reduce(T input, + T& output, + unsigned int valid_items, + storage_type& storage, + BinaryFunction reduce_op) { - (void) storage; // disables unused parameter warning + (void)storage; // disables unused parameter warning this->reduce(input, output, valid_items, reduce_op); } @@ -150,8 +153,10 @@ class warp_reduce_dpp void head_segmented_reduce(T input, T& output, Flag flag, BinaryFunction reduce_op) { // Fallback to shuffle-based implementation - warp_reduce_shuffle() - .head_segmented_reduce(input, output, flag, reduce_op); + warp_reduce_shuffle().head_segmented_reduce(input, + output, + flag, + reduce_op); } template @@ -159,28 +164,36 @@ class warp_reduce_dpp void tail_segmented_reduce(T input, T& output, Flag flag, BinaryFunction reduce_op) { // Fallback to shuffle-based implementation - warp_reduce_shuffle() - .tail_segmented_reduce(input, output, flag, reduce_op); + warp_reduce_shuffle().tail_segmented_reduce(input, + output, + flag, + reduce_op); } template ROCPRIM_DEVICE ROCPRIM_INLINE - void head_segmented_reduce(T input, T& output, Flag flag, - storage_type& storage, BinaryFunction reduce_op) + void head_segmented_reduce( + T input, T& output, Flag flag, storage_type& storage, BinaryFunction reduce_op) { // Fallback to shuffle-based implementation - warp_reduce_shuffle() - .head_segmented_reduce(input, output, flag, storage, reduce_op); + warp_reduce_shuffle().head_segmented_reduce(input, + output, + flag, + storage, + reduce_op); } template ROCPRIM_DEVICE ROCPRIM_INLINE - void tail_segmented_reduce(T input, T& output, Flag flag, - storage_type& storage, BinaryFunction reduce_op) + void tail_segmented_reduce( + T input, T& output, Flag flag, storage_type& storage, BinaryFunction reduce_op) { // Fallback to shuffle-based implementation - warp_reduce_shuffle() - .tail_segmented_reduce(input, output, flag, storage, reduce_op); + warp_reduce_shuffle().tail_segmented_reduce(input, + output, + flag, + storage, + reduce_op); } }; diff --git a/rocprim/include/rocprim/warp/detail/warp_reduce_shared_mem.hpp b/rocprim/include/rocprim/warp/detail/warp_reduce_shared_mem.hpp index 7c1d84074..04af8922c 100644 --- a/rocprim/include/rocprim/warp/detail/warp_reduce_shared_mem.hpp +++ b/rocprim/include/rocprim/warp/detail/warp_reduce_shared_mem.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -24,9 +24,9 @@ #include #include "../../config.hpp" +#include "../../detail/various.hpp" #include "../../intrinsics.hpp" #include "../../types.hpp" -#include "../../detail/various.hpp" #include "warp_segment_bounds.hpp" @@ -35,16 +35,12 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template< - class T, - unsigned int WarpSize, - bool UseAllReduce -> +template class warp_reduce_shared_mem { struct storage_type_ { - T values[WarpSize]; + T values[VirtualWaveSize]; }; public: @@ -56,17 +52,17 @@ class warp_reduce_shared_mem ROCPRIM_DEVICE ROCPRIM_INLINE void reduce(T input, T& output, storage_type& storage, BinaryFunction reduce_op) { - constexpr unsigned int ceiling = next_power_of_two(WarpSize); - const unsigned int lid = detail::logical_lane_id(); - storage_type_& storage_ = storage.get(); + constexpr unsigned int ceiling = next_power_of_two(VirtualWaveSize); + const unsigned int lid = detail::logical_lane_id(); + storage_type_& storage_ = storage.get(); - output = input; + output = input; storage_.values[lid] = output; ::rocprim::wave_barrier(); ROCPRIM_UNROLL for(unsigned int i = ceiling >> 1; i > 0; i >>= 1) { - const bool do_op = lid + i < WarpSize && lid < i; + const bool do_op = lid + i < VirtualWaveSize && lid < i; if(do_op) { output = storage_.values[lid]; @@ -85,20 +81,23 @@ class warp_reduce_shared_mem template ROCPRIM_DEVICE ROCPRIM_INLINE - void reduce(T input, T& output, unsigned int valid_items, - storage_type& storage, BinaryFunction reduce_op) + void reduce(T input, + T& output, + unsigned int valid_items, + storage_type& storage, + BinaryFunction reduce_op) { - constexpr unsigned int ceiling = next_power_of_two(WarpSize); - const unsigned int lid = detail::logical_lane_id(); - storage_type_& storage_ = storage.get(); + constexpr unsigned int ceiling = next_power_of_two(VirtualWaveSize); + const unsigned int lid = detail::logical_lane_id(); + storage_type_& storage_ = storage.get(); - output = input; + output = input; storage_.values[lid] = output; ::rocprim::wave_barrier(); ROCPRIM_UNROLL for(unsigned int i = ceiling >> 1; i > 0; i >>= 1) { - const bool do_op = (lid + i) < WarpSize && lid < i && (lid + i) < valid_items; + const bool do_op = (lid + i) < VirtualWaveSize && lid < i && (lid + i) < valid_items; if(do_op) { output = storage_.values[lid]; @@ -117,16 +116,16 @@ class warp_reduce_shared_mem template ROCPRIM_DEVICE ROCPRIM_INLINE - void head_segmented_reduce(T input, T& output, Flag flag, - storage_type& storage, BinaryFunction reduce_op) + void head_segmented_reduce( + T input, T& output, Flag flag, storage_type& storage, BinaryFunction reduce_op) { this->segmented_reduce(input, output, flag, storage, reduce_op); } template ROCPRIM_DEVICE ROCPRIM_INLINE - void tail_segmented_reduce(T input, T& output, Flag flag, - storage_type& storage, BinaryFunction reduce_op) + void tail_segmented_reduce( + T input, T& output, Flag flag, storage_type& storage, BinaryFunction reduce_op) { this->segmented_reduce(input, output, flag, storage, reduce_op); } @@ -134,14 +133,14 @@ class warp_reduce_shared_mem private: template ROCPRIM_DEVICE ROCPRIM_INLINE - void segmented_reduce(T input, T& output, Flag flag, - storage_type& storage, BinaryFunction reduce_op) + void segmented_reduce( + T input, T& output, Flag flag, storage_type& storage, BinaryFunction reduce_op) { - const unsigned int lid = detail::logical_lane_id(); - constexpr unsigned int ceiling = next_power_of_two(WarpSize); - storage_type_& storage_ = storage.get(); + const unsigned int lid = detail::logical_lane_id(); + constexpr unsigned int ceiling = next_power_of_two(VirtualWaveSize); + storage_type_& storage_ = storage.get(); // Get logical lane id of the last valid value in the segment - auto last = last_in_warp_segment(flag); + auto last = last_in_warp_segment(flag); output = input; ROCPRIM_UNROLL @@ -152,7 +151,7 @@ class warp_reduce_shared_mem if((lid + i) <= last) { T other = storage_.values[lid + i]; - output = reduce_op(output, other); + output = reduce_op(output, other); } ::rocprim::wave_barrier(); } @@ -160,18 +159,16 @@ class warp_reduce_shared_mem template ROCPRIM_DEVICE ROCPRIM_INLINE - typename std::enable_if<(Switch == false)>::type - set_output(T& output, storage_type& storage) + typename std::enable_if<(Switch == false)>::type set_output(T& output, storage_type& storage) { - (void) output; - (void) storage; + (void)output; + (void)storage; // output already set correctly } template ROCPRIM_DEVICE ROCPRIM_INLINE - typename std::enable_if<(Switch == true)>::type - set_output(T& output, storage_type& storage) + typename std::enable_if<(Switch == true)>::type set_output(T& output, storage_type& storage) { storage_type_& storage_ = storage.get(); output = storage_.values[0]; diff --git a/rocprim/include/rocprim/warp/detail/warp_reduce_shuffle.hpp b/rocprim/include/rocprim/warp/detail/warp_reduce_shuffle.hpp index 3a7bfd383..15273c87d 100644 --- a/rocprim/include/rocprim/warp/detail/warp_reduce_shuffle.hpp +++ b/rocprim/include/rocprim/warp/detail/warp_reduce_shuffle.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -24,9 +24,9 @@ #include #include "../../config.hpp" +#include "../../detail/various.hpp" #include "../../intrinsics.hpp" #include "../../types.hpp" -#include "../../detail/various.hpp" #include "warp_segment_bounds.hpp" @@ -35,15 +35,11 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template< - class T, - unsigned int WarpSize, - bool UseAllReduce -> +template class warp_reduce_shuffle { public: - static_assert(detail::is_power_of_two(WarpSize), "WarpSize must be power of 2"); + static_assert(detail::is_power_of_two(VirtualWaveSize), "VirtualWaveSize must be power of 2"); using storage_type = detail::empty_storage_type; @@ -55,9 +51,9 @@ class warp_reduce_shuffle T value; ROCPRIM_UNROLL - for(unsigned int offset = 1; offset < WarpSize; offset *= 2) + for(unsigned int offset = 1; offset < VirtualWaveSize; offset *= 2) { - value = warp_shuffle_down(output, offset, WarpSize); + value = warp_shuffle_down(output, offset, VirtualWaveSize); output = reduce_op(output, value); } set_output(output); @@ -67,7 +63,7 @@ class warp_reduce_shuffle ROCPRIM_DEVICE ROCPRIM_INLINE void reduce(T input, T& output, storage_type& storage, BinaryFunction reduce_op) { - (void) storage; // disables unused parameter warning + (void)storage; // disables unused parameter warning this->reduce(input, output, reduce_op); } @@ -79,21 +75,25 @@ class warp_reduce_shuffle T value; ROCPRIM_UNROLL - for(unsigned int offset = 1; offset < WarpSize; offset *= 2) + for(unsigned int offset = 1; offset < VirtualWaveSize; offset *= 2) { - value = warp_shuffle_down(output, offset, WarpSize); - unsigned int id = detail::logical_lane_id(); - if (id + offset < valid_items) output = reduce_op(output, value); + value = warp_shuffle_down(output, offset, VirtualWaveSize); + unsigned int id = detail::logical_lane_id(); + if(id + offset < valid_items) + output = reduce_op(output, value); } set_output(output); } template ROCPRIM_DEVICE ROCPRIM_INLINE - void reduce(T input, T& output, unsigned int valid_items, - storage_type& storage, BinaryFunction reduce_op) + void reduce(T input, + T& output, + unsigned int valid_items, + storage_type& storage, + BinaryFunction reduce_op) { - (void) storage; // disables unused parameter warning + (void)storage; // disables unused parameter warning this->reduce(input, output, valid_items, reduce_op); } @@ -113,19 +113,19 @@ class warp_reduce_shuffle template ROCPRIM_DEVICE ROCPRIM_INLINE - void head_segmented_reduce(T input, T& output, Flag flag, - storage_type& storage, BinaryFunction reduce_op) + void head_segmented_reduce( + T input, T& output, Flag flag, storage_type& storage, BinaryFunction reduce_op) { - (void) storage; + (void)storage; this->segmented_reduce(input, output, flag, reduce_op); } template ROCPRIM_DEVICE ROCPRIM_INLINE - void tail_segmented_reduce(T input, T& output, Flag flag, - storage_type& storage, BinaryFunction reduce_op) + void tail_segmented_reduce( + T input, T& output, Flag flag, storage_type& storage, BinaryFunction reduce_op) { - (void) storage; + (void)storage; this->segmented_reduce(input, output, flag, reduce_op); } @@ -136,25 +136,24 @@ class warp_reduce_shuffle { // Get logical lane id of the last valid value in the segment, // and convert it to number of valid values in segment. - auto valid_items_in_segment = last_in_warp_segment(flag) + 1U; + auto valid_items_in_segment + = last_in_warp_segment(flag) + 1U; this->reduce(input, output, valid_items_in_segment, reduce_op); } template ROCPRIM_DEVICE ROCPRIM_INLINE - typename std::enable_if<(Switch == false)>::type - set_output(T& output) + typename std::enable_if<(Switch == false)>::type set_output(T& output) { - (void) output; + (void)output; // output already set correctly } template ROCPRIM_DEVICE ROCPRIM_INLINE - typename std::enable_if<(Switch == true)>::type - set_output(T& output) + typename std::enable_if<(Switch == true)>::type set_output(T& output) { - output = warp_shuffle(output, 0, WarpSize); + output = warp_shuffle(output, 0, VirtualWaveSize); } }; diff --git a/rocprim/include/rocprim/warp/detail/warp_scan_crosslane.hpp b/rocprim/include/rocprim/warp/detail/warp_scan_crosslane.hpp index acf7bbb7d..92cc4f5b4 100644 --- a/rocprim/include/rocprim/warp/detail/warp_scan_crosslane.hpp +++ b/rocprim/include/rocprim/warp/detail/warp_scan_crosslane.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2018-2020 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -33,17 +33,10 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template< - class T, - unsigned int WarpSize, - bool UseDPP = ROCPRIM_DETAIL_USE_DPP -> -using warp_scan_crosslane = - typename std::conditional< - UseDPP, - warp_scan_dpp, - warp_scan_shuffle - >::type; +template +using warp_scan_crosslane = typename std::conditional, + warp_scan_shuffle>::type; } // end namespace detail diff --git a/rocprim/include/rocprim/warp/detail/warp_scan_dpp.hpp b/rocprim/include/rocprim/warp/detail/warp_scan_dpp.hpp index 8422cbf61..b56108862 100644 --- a/rocprim/include/rocprim/warp/detail/warp_scan_dpp.hpp +++ b/rocprim/include/rocprim/warp/detail/warp_scan_dpp.hpp @@ -34,15 +34,11 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template< - class T, - unsigned int WarpSize -> +template class warp_scan_dpp { public: - static_assert(detail::is_power_of_two(WarpSize), "WarpSize must be power of 2"); + static_assert(detail::is_power_of_two(VirtualWaveSize), "VirtualWaveSize must be power of 2"); using storage_type = detail::empty_storage_type; @@ -50,60 +46,66 @@ class warp_scan_dpp ROCPRIM_DEVICE ROCPRIM_INLINE void inclusive_scan(T input, T& output, BinaryFunction scan_op) { - const unsigned int lane_id = ::rocprim::lane_id(); - const unsigned int row_lane_id = lane_id % ::rocprim::min(16u, WarpSize); + const unsigned int lane_id = ::rocprim::lane_id(); + const unsigned int row_lane_id = lane_id % ::rocprim::min(16u, VirtualWaveSize); output = input; - if(WarpSize > 1) + if(VirtualWaveSize > 1) { T t = scan_op(warp_move_dpp(output), output); // row_shr:1 - if(row_lane_id >= 1) output = t; + if(row_lane_id >= 1) + output = t; } - if(WarpSize > 2) + if(VirtualWaveSize > 2) { T t = scan_op(warp_move_dpp(output), output); // row_shr:2 - if(row_lane_id >= 2) output = t; + if(row_lane_id >= 2) + output = t; } - if(WarpSize > 4) + if(VirtualWaveSize > 4) { T t = scan_op(warp_move_dpp(output), output); // row_shr:4 - if(row_lane_id >= 4) output = t; + if(row_lane_id >= 4) + output = t; } - if(WarpSize > 8) + if(VirtualWaveSize > 8) { T t = scan_op(warp_move_dpp(output), output); // row_shr:8 - if(row_lane_id >= 8) output = t; + if(row_lane_id >= 8) + output = t; } #ifdef ROCPRIM_DETAIL_HAS_DPP_BROADCAST - if(WarpSize > 16) + if(VirtualWaveSize > 16) { T t = scan_op(warp_move_dpp(output), output); // row_bcast:15 - if(lane_id % 32 >= 16) output = t; + if(lane_id % 32 >= 16) + output = t; } - if(WarpSize > 32) + if(VirtualWaveSize > 32) { T t = scan_op(warp_move_dpp(output), output); // row_bcast:31 - if(lane_id >= 32) output = t; + if(lane_id >= 32) + output = t; } - static_assert(WarpSize <= 64, "WarpSize > 64 is not supported"); + static_assert(VirtualWaveSize <= 64, "VirtualWaveSize > 64 is not supported"); #else - if(WarpSize > 16) + if(VirtualWaveSize > 16) { T t = scan_op(warp_swizzle(output), output); // row_bcast:15 if(lane_id % 32 >= 16) output = t; } - static_assert(WarpSize <= 32, "WarpSize > 32 is not supported without DPP broadcasts"); + static_assert(VirtualWaveSize <= 32, + "VirtualWaveSize > 32 is not supported without DPP broadcasts"); #endif } template ROCPRIM_DEVICE ROCPRIM_INLINE - void inclusive_scan(T input, T& output, - storage_type& storage, BinaryFunction scan_op) + void inclusive_scan(T input, T& output, storage_type& storage, BinaryFunction scan_op) { - (void) storage; // disables unused parameter warning + (void)storage; // disables unused parameter warning inclusive_scan(input, output, scan_op); } @@ -126,20 +128,19 @@ class warp_scan_dpp template ROCPRIM_DEVICE ROCPRIM_INLINE - void inclusive_scan(T input, T& output, T& reduction, - BinaryFunction scan_op) + void inclusive_scan(T input, T& output, T& reduction, BinaryFunction scan_op) { inclusive_scan(input, output, scan_op); - // Broadcast value from the last thread in warp - reduction = warp_shuffle(output, WarpSize-1, WarpSize); + // Broadcast value from the last thread in the warp + reduction = warp_shuffle(output, VirtualWaveSize - 1, VirtualWaveSize); } template ROCPRIM_DEVICE ROCPRIM_INLINE - void inclusive_scan(T input, T& output, T& reduction, - storage_type& storage, BinaryFunction scan_op) + void inclusive_scan( + T input, T& output, T& reduction, storage_type& storage, BinaryFunction scan_op) { - (void) storage; + (void)storage; inclusive_scan(input, output, reduction, scan_op); } @@ -148,10 +149,10 @@ class warp_scan_dpp void inclusive_scan(T input, T& output, T& reduction, BinaryFunction scan_op, T init) { inclusive_scan(input, output, scan_op); + // Broadcast value from the last thread in the warp + reduction = warp_shuffle(output, VirtualWaveSize - 1, VirtualWaveSize); // Include init value in scan results output = scan_op(init, output); - // Broadcast value from the last thread in warp - reduction = warp_shuffle(output, WarpSize - 1, WarpSize); } template @@ -174,59 +175,57 @@ class warp_scan_dpp template ROCPRIM_DEVICE ROCPRIM_INLINE - void exclusive_scan(T input, T& output, T init, - storage_type& storage, BinaryFunction scan_op) + void exclusive_scan(T input, T& output, T init, storage_type& storage, BinaryFunction scan_op) { - (void) storage; // disables unused parameter warning + (void)storage; // disables unused parameter warning exclusive_scan(input, output, init, scan_op); } template ROCPRIM_DEVICE ROCPRIM_INLINE - void exclusive_scan(T input, T& output, - storage_type& storage, BinaryFunction scan_op) + void exclusive_scan(T input, T& output, storage_type& storage, BinaryFunction scan_op) { - (void) storage; // disables unused parameter warning + (void)storage; // disables unused parameter warning inclusive_scan(input, output, scan_op); // Convert inclusive scan result to exclusive to_exclusive(output, output); } template - ROCPRIM_DEVICE ROCPRIM_INLINE void exclusive_scan( + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan( T input, T& output, storage_type& /*storage*/, T& reduction, BinaryFunction scan_op) { inclusive_scan(input, output, scan_op); - // Broadcast value from the last thread in warp - reduction = warp_shuffle(output, WarpSize - 1, WarpSize); + // Broadcast value from the last thread in the warp + reduction = warp_shuffle(output, VirtualWaveSize - 1, VirtualWaveSize); // Convert inclusive scan result to exclusive to_exclusive(output, output); } template - ROCPRIM_DEVICE ROCPRIM_INLINE void - exclusive_scan(T input, T& output, T init, T& reduction, BinaryFunction scan_op) + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan(T input, T& output, T init, T& reduction, BinaryFunction scan_op) { inclusive_scan(input, output, scan_op); - // Broadcast value from the last thread in warp - reduction = warp_shuffle(output, WarpSize-1, WarpSize); + // Broadcast value from the last thread in the warp + reduction = warp_shuffle(output, VirtualWaveSize - 1, VirtualWaveSize); // Convert inclusive scan result to exclusive to_exclusive(output, output, init, scan_op); } template ROCPRIM_DEVICE ROCPRIM_INLINE - void exclusive_scan(T input, T& output, T init, T& reduction, - storage_type& storage, BinaryFunction scan_op) + void exclusive_scan( + T input, T& output, T init, T& reduction, storage_type& storage, BinaryFunction scan_op) { - (void) storage; + (void)storage; exclusive_scan(input, output, init, reduction, scan_op); } template ROCPRIM_DEVICE ROCPRIM_INLINE - void scan(T input, T& inclusive_output, T& exclusive_output, T init, - BinaryFunction scan_op) + void scan(T input, T& inclusive_output, T& exclusive_output, T init, BinaryFunction scan_op) { inclusive_scan(input, inclusive_output, scan_op); // Convert inclusive scan result to exclusive @@ -235,19 +234,26 @@ class warp_scan_dpp template ROCPRIM_DEVICE ROCPRIM_INLINE - void scan(T input, T& inclusive_output, T& exclusive_output, T init, - storage_type& storage, BinaryFunction scan_op) + void scan(T input, + T& inclusive_output, + T& exclusive_output, + T init, + storage_type& storage, + BinaryFunction scan_op) { - (void) storage; // disables unused parameter warning + (void)storage; // disables unused parameter warning scan(input, inclusive_output, exclusive_output, init, scan_op); } template ROCPRIM_DEVICE ROCPRIM_INLINE - void scan(T input, T& inclusive_output, T& exclusive_output, - storage_type& storage, BinaryFunction scan_op) + void scan(T input, + T& inclusive_output, + T& exclusive_output, + storage_type& storage, + BinaryFunction scan_op) { - (void) storage; // disables unused parameter warning + (void)storage; // disables unused parameter warning inclusive_scan(input, inclusive_output, scan_op); // Convert inclusive scan result to exclusive to_exclusive(inclusive_output, exclusive_output); @@ -255,58 +261,58 @@ class warp_scan_dpp template ROCPRIM_DEVICE ROCPRIM_INLINE - void scan(T input, T& inclusive_output, T& exclusive_output, T init, T& reduction, + void scan(T input, + T& inclusive_output, + T& exclusive_output, + T init, + T& reduction, BinaryFunction scan_op) { inclusive_scan(input, inclusive_output, scan_op); - // Broadcast value from the last thread in warp - reduction = warp_shuffle(inclusive_output, WarpSize-1, WarpSize); + // Broadcast value from the last thread in the warp + reduction = warp_shuffle(inclusive_output, VirtualWaveSize - 1, VirtualWaveSize); // Convert inclusive scan result to exclusive to_exclusive(inclusive_output, exclusive_output, init, scan_op); } template ROCPRIM_DEVICE ROCPRIM_INLINE - void scan(T input, T& inclusive_output, T& exclusive_output, T init, T& reduction, - storage_type& storage, BinaryFunction scan_op) + void scan(T input, + T& inclusive_output, + T& exclusive_output, + T init, + T& reduction, + storage_type& storage, + BinaryFunction scan_op) { - (void) storage; + (void)storage; scan(input, inclusive_output, exclusive_output, init, reduction, scan_op); } ROCPRIM_DEVICE ROCPRIM_INLINE T broadcast(T input, const unsigned int src_lane, storage_type& storage) { - (void) storage; + (void)storage; - if(WarpSize == ::rocprim::arch::wavefront::min_size()) + if(VirtualWaveSize == ::rocprim::arch::wavefront::size()) { return warp_readlane(input, warp_readfirstlane(src_lane)); } - return warp_shuffle(input, src_lane, WarpSize); - } - -protected: - [[deprecated]] ROCPRIM_DEVICE ROCPRIM_INLINE void - to_exclusive(T inclusive_input, T& exclusive_output, storage_type& storage) - { - (void) storage; - return to_exclusive(inclusive_input, exclusive_output); + return warp_shuffle(input, src_lane, VirtualWaveSize); } private: // Changes inclusive scan results to exclusive scan results template ROCPRIM_DEVICE ROCPRIM_INLINE - void to_exclusive(T inclusive_input, T& exclusive_output, T init, - BinaryFunction scan_op) + void to_exclusive(T inclusive_input, T& exclusive_output, T init, BinaryFunction scan_op) { // include init value in scan results exclusive_output = scan_op(init, inclusive_input); // get exclusive results - exclusive_output = warp_shuffle_up(exclusive_output, 1, WarpSize); - if(detail::logical_lane_id() == 0) + exclusive_output = warp_shuffle_up(exclusive_output, 1, VirtualWaveSize); + if(detail::logical_lane_id() == 0) { exclusive_output = init; } @@ -316,7 +322,7 @@ class warp_scan_dpp void to_exclusive(T inclusive_input, T& exclusive_output) { // shift to get exclusive results - exclusive_output = warp_shuffle_up(inclusive_input, 1, WarpSize); + exclusive_output = warp_shuffle_up(inclusive_input, 1, VirtualWaveSize); } }; diff --git a/rocprim/include/rocprim/warp/detail/warp_scan_shared_mem.hpp b/rocprim/include/rocprim/warp/detail/warp_scan_shared_mem.hpp index 4ae09ea7b..9b5a2215d 100644 --- a/rocprim/include/rocprim/warp/detail/warp_scan_shared_mem.hpp +++ b/rocprim/include/rocprim/warp/detail/warp_scan_shared_mem.hpp @@ -34,16 +34,14 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template< - class T, - unsigned int WarpSize -> +template class warp_scan_shared_mem { struct storage_type_ { - T threads[WarpSize]; + T threads[VirtualWaveSize]; }; + public: ROCPRIM_DETAIL_SUPPRESS_DEPRECATION_WITH_PUSH using storage_type = detail::raw_storage; @@ -51,16 +49,15 @@ class warp_scan_shared_mem template ROCPRIM_DEVICE ROCPRIM_INLINE - void inclusive_scan(T input, T& output, - storage_type& storage, BinaryFunction scan_op) + void inclusive_scan(T input, T& output, storage_type& storage, BinaryFunction scan_op) { - const unsigned int lid = detail::logical_lane_id(); - storage_type_& storage_ = storage.get(); + const unsigned int lid = detail::logical_lane_id(); + storage_type_& storage_ = storage.get(); - T me = input; + T me = input; storage_.threads[lid] = me; ::rocprim::wave_barrier(); - for(unsigned int i = 1; i < WarpSize; i *= 2) + for(unsigned int i = 1; i < VirtualWaveSize; i *= 2) { const bool do_op = lid >= i; if(do_op) @@ -82,13 +79,13 @@ class warp_scan_shared_mem ROCPRIM_DEVICE ROCPRIM_INLINE void inclusive_scan(T input, T& output, storage_type& storage, BinaryFunction scan_op, T init) { - const unsigned int lid = detail::logical_lane_id(); + const unsigned int lid = detail::logical_lane_id(); storage_type_& storage_ = storage.get(); T me = input; storage_.threads[lid] = me; ::rocprim::wave_barrier(); - for(unsigned int i = 1; i < WarpSize; i *= 2) + for(unsigned int i = 1; i < VirtualWaveSize; i *= 2) { const bool do_op = lid >= i; if(do_op) @@ -103,18 +100,20 @@ class warp_scan_shared_mem } ::rocprim::wave_barrier(); } - output = scan_op(init, me); - storage_.threads[lid] = output; + + // Apply the initial value. Do not write the result + // of applying the initial value to memory. + output = scan_op(init, me); } template ROCPRIM_DEVICE ROCPRIM_INLINE - void inclusive_scan(T input, T& output, T& reduction, - storage_type& storage, BinaryFunction scan_op) + void inclusive_scan( + T input, T& output, T& reduction, storage_type& storage, BinaryFunction scan_op) { storage_type_& storage_ = storage.get(); inclusive_scan(input, output, storage, scan_op); - reduction = storage_.threads[WarpSize - 1]; + reduction = storage_.threads[VirtualWaveSize - 1]; } template @@ -125,13 +124,12 @@ class warp_scan_shared_mem storage_type_& storage_ = storage.get(); inclusive_scan(input, output, storage, scan_op, init); ::rocprim::wave_barrier(); - reduction = storage_.threads[WarpSize - 1]; + reduction = storage_.threads[VirtualWaveSize - 1]; } template ROCPRIM_DEVICE ROCPRIM_INLINE - void exclusive_scan(T input, T& output, T init, - storage_type& storage, BinaryFunction scan_op) + void exclusive_scan(T input, T& output, T init, storage_type& storage, BinaryFunction scan_op) { inclusive_scan(input, output, storage, scan_op); to_exclusive(output, init, storage, scan_op); @@ -139,37 +137,41 @@ class warp_scan_shared_mem template ROCPRIM_DEVICE ROCPRIM_INLINE - void exclusive_scan(T input, T& output, - storage_type& storage, BinaryFunction scan_op) + void exclusive_scan(T input, T& output, storage_type& storage, BinaryFunction scan_op) { inclusive_scan(input, output, storage, scan_op); to_exclusive(output, storage); } template - ROCPRIM_DEVICE ROCPRIM_INLINE void exclusive_scan( + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan( T input, T& output, storage_type& storage, T& reduction, BinaryFunction scan_op) { inclusive_scan(input, output, storage, scan_op); - reduction = storage.get().threads[WarpSize - 1]; + reduction = storage.get().threads[VirtualWaveSize - 1]; to_exclusive(output, storage); } template ROCPRIM_DEVICE ROCPRIM_INLINE - void exclusive_scan(T input, T& output, T init, T& reduction, - storage_type& storage, BinaryFunction scan_op) + void exclusive_scan( + T input, T& output, T init, T& reduction, storage_type& storage, BinaryFunction scan_op) { storage_type_& storage_ = storage.get(); inclusive_scan(input, output, storage, scan_op); - reduction = storage_.threads[WarpSize - 1]; + reduction = storage_.threads[VirtualWaveSize - 1]; to_exclusive(output, init, storage, scan_op); } template ROCPRIM_DEVICE ROCPRIM_INLINE - void scan(T input, T& inclusive_output, T& exclusive_output, T init, - storage_type& storage, BinaryFunction scan_op) + void scan(T input, + T& inclusive_output, + T& exclusive_output, + T init, + storage_type& storage, + BinaryFunction scan_op) { inclusive_scan(input, inclusive_output, storage, scan_op); to_exclusive(exclusive_output, init, storage, scan_op); @@ -177,8 +179,11 @@ class warp_scan_shared_mem template ROCPRIM_DEVICE ROCPRIM_INLINE - void scan(T input, T& inclusive_output, T& exclusive_output, - storage_type& storage, BinaryFunction scan_op) + void scan(T input, + T& inclusive_output, + T& exclusive_output, + storage_type& storage, + BinaryFunction scan_op) { inclusive_scan(input, inclusive_output, storage, scan_op); to_exclusive(exclusive_output, storage); @@ -186,12 +191,17 @@ class warp_scan_shared_mem template ROCPRIM_DEVICE ROCPRIM_INLINE - void scan(T input, T& inclusive_output, T& exclusive_output, T init, T& reduction, - storage_type& storage, BinaryFunction scan_op) + void scan(T input, + T& inclusive_output, + T& exclusive_output, + T init, + T& reduction, + storage_type& storage, + BinaryFunction scan_op) { storage_type_& storage_ = storage.get(); inclusive_scan(input, inclusive_output, storage, scan_op); - reduction = storage_.threads[WarpSize - 1]; + reduction = storage_.threads[VirtualWaveSize - 1]; ::rocprim::wave_barrier(); to_exclusive(exclusive_output, init, storage, scan_op); } @@ -200,7 +210,7 @@ class warp_scan_shared_mem T broadcast(T input, const unsigned int src_lane, storage_type& storage) { storage_type_& storage_ = storage.get(); - if(src_lane == detail::logical_lane_id()) + if(src_lane == detail::logical_lane_id()) { storage_.threads[src_lane] = input; } @@ -208,24 +218,15 @@ class warp_scan_shared_mem return storage_.threads[src_lane]; } -protected: - [[deprecated]] ROCPRIM_DEVICE ROCPRIM_INLINE void - to_exclusive(T inclusive_input, T& exclusive_output, storage_type& storage) - { - (void) inclusive_input; - return to_exclusive(exclusive_output, storage); - } - private: // Calculate exclusive results base on inclusive scan results in storage.threads[]. template ROCPRIM_DEVICE ROCPRIM_INLINE - void to_exclusive(T& exclusive_output, T init, - storage_type& storage, BinaryFunction scan_op) + void to_exclusive(T& exclusive_output, T init, storage_type& storage, BinaryFunction scan_op) { - const unsigned int lid = detail::logical_lane_id(); - storage_type_& storage_ = storage.get(); - exclusive_output = init; + const unsigned int lid = detail::logical_lane_id(); + storage_type_& storage_ = storage.get(); + exclusive_output = init; if(lid != 0) { exclusive_output = scan_op(init, storage_.threads[lid - 1]); @@ -235,8 +236,8 @@ class warp_scan_shared_mem ROCPRIM_DEVICE ROCPRIM_INLINE void to_exclusive(T& exclusive_output, storage_type& storage) { - const unsigned int lid = detail::logical_lane_id(); - storage_type_& storage_ = storage.get(); + const unsigned int lid = detail::logical_lane_id(); + storage_type_& storage_ = storage.get(); if(lid != 0) { exclusive_output = storage_.threads[lid - 1]; diff --git a/rocprim/include/rocprim/warp/detail/warp_scan_shuffle.hpp b/rocprim/include/rocprim/warp/detail/warp_scan_shuffle.hpp index 95a58a5b2..66767d220 100644 --- a/rocprim/include/rocprim/warp/detail/warp_scan_shuffle.hpp +++ b/rocprim/include/rocprim/warp/detail/warp_scan_shuffle.hpp @@ -27,6 +27,7 @@ #include "../../detail/various.hpp" #include "../../intrinsics.hpp" +#include "../../intrinsics/warp_shuffle.hpp" #include "../../types.hpp" BEGIN_ROCPRIM_NAMESPACE @@ -34,14 +35,11 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template< - class T, - unsigned int WarpSize -> +template class warp_scan_shuffle { public: - static_assert(detail::is_power_of_two(WarpSize), "WarpSize must be power of 2"); + static_assert(detail::is_power_of_two(VirtualWaveSize), "VirtualWaveSize must be power of 2"); using storage_type = detail::empty_storage_type; @@ -51,22 +49,22 @@ class warp_scan_shuffle { output = input; - T value; - const unsigned int id = detail::logical_lane_id(); + T value; + const unsigned int id = detail::logical_lane_id(); ROCPRIM_UNROLL - for(unsigned int offset = 1; offset < WarpSize; offset *= 2) + for(unsigned int offset = 1; offset < VirtualWaveSize; offset *= 2) { - value = warp_shuffle_up(output, offset, WarpSize); - if(id >= offset) output = scan_op(value, output); + value = warp_shuffle_up(output, offset, VirtualWaveSize); + if(id >= offset) + output = scan_op(value, output); } } template ROCPRIM_DEVICE ROCPRIM_INLINE - void inclusive_scan(T input, T& output, - storage_type& storage, BinaryFunction scan_op) + void inclusive_scan(T input, T& output, storage_type& storage, BinaryFunction scan_op) { - (void) storage; // disables unused parameter warning + (void)storage; // disables unused parameter warning inclusive_scan(input, output, scan_op); } @@ -77,11 +75,11 @@ class warp_scan_shuffle output = input; T value; - const unsigned int id = detail::logical_lane_id(); + const unsigned int id = detail::logical_lane_id(); ROCPRIM_UNROLL - for(unsigned int offset = 1; offset < WarpSize; offset *= 2) + for(unsigned int offset = 1; offset < VirtualWaveSize; offset *= 2) { - value = warp_shuffle_up(output, offset, WarpSize); + value = warp_shuffle_up(output, offset, VirtualWaveSize); if(id >= offset) output = scan_op(value, output); } @@ -99,20 +97,19 @@ class warp_scan_shuffle template ROCPRIM_DEVICE ROCPRIM_INLINE - void inclusive_scan(T input, T& output, T& reduction, - BinaryFunction scan_op) + void inclusive_scan(T input, T& output, T& reduction, BinaryFunction scan_op) { inclusive_scan(input, output, scan_op); - // Broadcast value from the last thread in warp - reduction = warp_shuffle(output, WarpSize-1, WarpSize); + // Broadcast value from the last thread in the warp + reduction = warp_shuffle(output, VirtualWaveSize - 1, VirtualWaveSize); } template ROCPRIM_DEVICE ROCPRIM_INLINE - void inclusive_scan(T input, T& output, T& reduction, - storage_type& storage, BinaryFunction scan_op) + void inclusive_scan( + T input, T& output, T& reduction, storage_type& storage, BinaryFunction scan_op) { - (void) storage; + (void)storage; inclusive_scan(input, output, reduction, scan_op); } @@ -121,10 +118,10 @@ class warp_scan_shuffle void inclusive_scan(T input, T& output, T& reduction, BinaryFunction scan_op, T init) { inclusive_scan(input, output, scan_op); + // Broadcast value from the last thread in the warp + reduction = warp_shuffle(output, VirtualWaveSize - 1, VirtualWaveSize); // Include init value in scan results output = scan_op(init, output); - // Broadcast value from the last thread in warp - reduction = warp_shuffle(output, WarpSize - 1, WarpSize); } template @@ -147,60 +144,57 @@ class warp_scan_shuffle template ROCPRIM_DEVICE ROCPRIM_INLINE - void exclusive_scan(T input, T& output, T init, - storage_type& storage, BinaryFunction scan_op) + void exclusive_scan(T input, T& output, T init, storage_type& storage, BinaryFunction scan_op) { - (void) storage; // disables unused parameter warning + (void)storage; // disables unused parameter warning exclusive_scan(input, output, init, scan_op); } template ROCPRIM_DEVICE ROCPRIM_INLINE - void exclusive_scan(T input, T& output, - storage_type& storage, BinaryFunction scan_op) + void exclusive_scan(T input, T& output, storage_type& storage, BinaryFunction scan_op) { - (void) storage; // disables unused parameter warning + (void)storage; // disables unused parameter warning inclusive_scan(input, output, scan_op); // Convert inclusive scan result to exclusive to_exclusive(output, output); } template - ROCPRIM_DEVICE ROCPRIM_INLINE void exclusive_scan( + ROCPRIM_DEVICE ROCPRIM_INLINE + void exclusive_scan( T input, T& output, storage_type& /*storage*/, T& reduction, BinaryFunction scan_op) { inclusive_scan(input, output, scan_op); - // Broadcast value from the last thread in warp - reduction = warp_shuffle(output, WarpSize - 1, WarpSize); + // Broadcast value from the last thread in the warp + reduction = warp_shuffle(output, VirtualWaveSize - 1, VirtualWaveSize); // Convert inclusive scan result to exclusive to_exclusive(output, output); } template ROCPRIM_DEVICE ROCPRIM_INLINE - void exclusive_scan(T input, T& output, T init, T& reduction, - BinaryFunction scan_op) + void exclusive_scan(T input, T& output, T init, T& reduction, BinaryFunction scan_op) { inclusive_scan(input, output, scan_op); - // Broadcast value from the last thread in warp - reduction = warp_shuffle(output, WarpSize-1, WarpSize); + // Broadcast value from the last thread in the warp + reduction = warp_shuffle(output, VirtualWaveSize - 1, VirtualWaveSize); // Convert inclusive scan result to exclusive to_exclusive(output, output, init, scan_op); } template ROCPRIM_DEVICE ROCPRIM_INLINE - void exclusive_scan(T input, T& output, T init, T& reduction, - storage_type& storage, BinaryFunction scan_op) + void exclusive_scan( + T input, T& output, T init, T& reduction, storage_type& storage, BinaryFunction scan_op) { - (void) storage; + (void)storage; exclusive_scan(input, output, init, reduction, scan_op); } template ROCPRIM_DEVICE ROCPRIM_INLINE - void scan(T input, T& inclusive_output, T& exclusive_output, T init, - BinaryFunction scan_op) + void scan(T input, T& inclusive_output, T& exclusive_output, T init, BinaryFunction scan_op) { inclusive_scan(input, inclusive_output, scan_op); // Convert inclusive scan result to exclusive @@ -209,19 +203,26 @@ class warp_scan_shuffle template ROCPRIM_DEVICE ROCPRIM_INLINE - void scan(T input, T& inclusive_output, T& exclusive_output, T init, - storage_type& storage, BinaryFunction scan_op) + void scan(T input, + T& inclusive_output, + T& exclusive_output, + T init, + storage_type& storage, + BinaryFunction scan_op) { - (void) storage; // disables unused parameter warning + (void)storage; // disables unused parameter warning scan(input, inclusive_output, exclusive_output, init, scan_op); } template ROCPRIM_DEVICE ROCPRIM_INLINE - void scan(T input, T& inclusive_output, T& exclusive_output, - storage_type& storage, BinaryFunction scan_op) + void scan(T input, + T& inclusive_output, + T& exclusive_output, + storage_type& storage, + BinaryFunction scan_op) { - (void) storage; // disables unused parameter warning + (void)storage; // disables unused parameter warning inclusive_scan(input, inclusive_output, scan_op); // Convert inclusive scan result to exclusive to_exclusive(inclusive_output, exclusive_output); @@ -229,52 +230,52 @@ class warp_scan_shuffle template ROCPRIM_DEVICE ROCPRIM_INLINE - void scan(T input, T& inclusive_output, T& exclusive_output, T init, T& reduction, + void scan(T input, + T& inclusive_output, + T& exclusive_output, + T init, + T& reduction, BinaryFunction scan_op) { inclusive_scan(input, inclusive_output, scan_op); - // Broadcast value from the last thread in warp - reduction = warp_shuffle(inclusive_output, WarpSize-1, WarpSize); + // Broadcast value from the last thread in the warp + reduction = warp_shuffle(inclusive_output, VirtualWaveSize - 1, VirtualWaveSize); // Convert inclusive scan result to exclusive to_exclusive(inclusive_output, exclusive_output, init, scan_op); } template ROCPRIM_DEVICE ROCPRIM_INLINE - void scan(T input, T& inclusive_output, T& exclusive_output, T init, T& reduction, - storage_type& storage, BinaryFunction scan_op) + void scan(T input, + T& inclusive_output, + T& exclusive_output, + T init, + T& reduction, + storage_type& storage, + BinaryFunction scan_op) { - (void) storage; + (void)storage; scan(input, inclusive_output, exclusive_output, init, reduction, scan_op); } ROCPRIM_DEVICE ROCPRIM_INLINE T broadcast(T input, const unsigned int src_lane, storage_type& storage) { - (void) storage; - return warp_shuffle(input, src_lane, WarpSize); - } - -protected: - [[deprecated]] ROCPRIM_DEVICE ROCPRIM_INLINE void - to_exclusive(T inclusive_input, T& exclusive_output, storage_type& storage) - { - (void) storage; - return to_exclusive(inclusive_input, exclusive_output); + (void)storage; + return warp_shuffle(input, src_lane, VirtualWaveSize); } private: // Changes inclusive scan results to exclusive scan results template ROCPRIM_DEVICE ROCPRIM_INLINE - void to_exclusive(T inclusive_input, T& exclusive_output, T init, - BinaryFunction scan_op) + void to_exclusive(T inclusive_input, T& exclusive_output, T init, BinaryFunction scan_op) { // include init value in scan results exclusive_output = scan_op(init, inclusive_input); // get exclusive results - exclusive_output = warp_shuffle_up(exclusive_output, 1, WarpSize); - if(detail::logical_lane_id() == 0) + exclusive_output = warp_shuffle_up(exclusive_output, 1, VirtualWaveSize); + if(detail::logical_lane_id() == 0) { exclusive_output = init; } @@ -284,7 +285,7 @@ class warp_scan_shuffle void to_exclusive(T inclusive_input, T& exclusive_output) { // shift to get exclusive results - exclusive_output = warp_shuffle_up(inclusive_input, 1, WarpSize); + exclusive_output = warp_shuffle_up(inclusive_input, 1, VirtualWaveSize); } }; diff --git a/rocprim/include/rocprim/warp/detail/warp_segment_bounds.hpp b/rocprim/include/rocprim/warp/detail/warp_segment_bounds.hpp index db267bbe6..13ba653cc 100644 --- a/rocprim/include/rocprim/warp/detail/warp_segment_bounds.hpp +++ b/rocprim/include/rocprim/warp/detail/warp_segment_bounds.hpp @@ -24,7 +24,10 @@ #include #include "../../config.hpp" -#include "../../intrinsics.hpp" +#include "../../intrinsics/arch.hpp" +#include "../../intrinsics/thread.hpp" +#include "../../intrinsics/warp.hpp" +#include "../../types.hpp" BEGIN_ROCPRIM_NAMESPACE @@ -32,10 +35,13 @@ namespace detail { // Returns logical warp id of the last thread in thread's segment -template +template ROCPRIM_DEVICE ROCPRIM_INLINE auto last_in_warp_segment(Flag flag) -> - typename std::enable_if<(WarpSize <= arch::wavefront::min_size()), unsigned int>::type + typename std::enable_if<(WarpSize <= arch::wavefront::max_size()), unsigned int>::type { // Get flags (now every thread know where the flags are) lane_mask_type warp_flags = ::rocprim::ballot(flag); @@ -53,11 +59,24 @@ auto last_in_warp_segment(Flag flag) -> // Make sure last item in logical warp is marked as a tail warp_flags |= lane_mask_type(1) << (WarpSize - 1U); // Calculate logical lane id of the last valid value in the segment -#if ROCPRIM_WAVEFRONT_SIZE == 32 - return ::__ffs(warp_flags) - 1; -#else - return ::__ffsll(warp_flags) - 1; -#endif + + if constexpr(Target == arch::wavefront::target::size32) + { + // The static_cast prevents "error: call to '__ffs' is ambiguous" + return ::__ffs(static_cast(warp_flags)) - 1; + } + else if constexpr(Target == arch::wavefront::target::size64) + { + // The static_cast prevents "error: call to '__ffsll' is ambiguous" + return ::__ffsll(static_cast(warp_flags)) - 1; + } + else + { + // Dynamic case, used for SPIR-V. + return arch::wavefront::size() == ROCPRIM_WARP_SIZE_32 + ? ::__ffs(static_cast(warp_flags)) - 1 + : ::__ffsll(static_cast(warp_flags)) - 1; + } } } // end namespace detail diff --git a/rocprim/include/rocprim/warp/detail/warp_sort_shuffle.hpp b/rocprim/include/rocprim/warp/detail/warp_sort_shuffle.hpp index d4bd2cee7..e585c636d 100644 --- a/rocprim/include/rocprim/warp/detail/warp_sort_shuffle.hpp +++ b/rocprim/include/rocprim/warp/detail/warp_sort_shuffle.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -34,12 +34,13 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template +template class warp_sort_shuffle { private: template - ROCPRIM_DEVICE ROCPRIM_INLINE typename std::enable_if warp)>::type + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if warp)>::type swap(Key& k, V& v, bool dir, BinaryFunction compare_function) { (void)k; @@ -49,21 +50,25 @@ class warp_sort_shuffle } template - ROCPRIM_DEVICE ROCPRIM_INLINE typename std::enable_if<(WarpSize > warp)>::type + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if<(VirtualWaveSize > warp)>::type swap(Key& k, V& v, bool dir, BinaryFunction compare_function) { - Key k1 = warp_swizzle_shuffle(k, xor_mask, WarpSize); + Key k1 = warp_swizzle_shuffle(k, xor_mask, VirtualWaveSize); bool swap = compare_function(dir ? k : k1, dir ? k1 : k); if(swap) { k = k1; - v = warp_swizzle_shuffle(v, xor_mask, WarpSize); + v = warp_swizzle_shuffle(v, xor_mask, VirtualWaveSize); } } template - ROCPRIM_DEVICE ROCPRIM_INLINE typename std::enable_if warp)>::type swap( - Key (&k)[ItemsPerThread], V (&v)[ItemsPerThread], bool dir, BinaryFunction compare_function) + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if warp)>::type swap(Key (&k)[ItemsPerThread], + V (&v)[ItemsPerThread], + bool dir, + BinaryFunction compare_function) { (void)k; (void)v; @@ -72,25 +77,29 @@ class warp_sort_shuffle } template - ROCPRIM_DEVICE ROCPRIM_INLINE typename std::enable_if<(WarpSize > warp)>::type swap( - Key (&k)[ItemsPerThread], V (&v)[ItemsPerThread], bool dir, BinaryFunction compare_function) + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if<(VirtualWaveSize > warp)>::type swap(Key (&k)[ItemsPerThread], + V (&v)[ItemsPerThread], + bool dir, + BinaryFunction compare_function) { Key k1[ItemsPerThread]; ROCPRIM_UNROLL for(unsigned int item = 0; item < ItemsPerThread; item++) { - k1[item] = warp_swizzle_shuffle(k[item], xor_mask, WarpSize); + k1[item] = warp_swizzle_shuffle(k[item], xor_mask, VirtualWaveSize); bool swap = compare_function(dir ? k[item] : k1[item], dir ? k1[item] : k[item]); if(swap) { k[item] = k1[item]; - v[item] = warp_swizzle_shuffle(v[item], xor_mask, WarpSize); + v[item] = warp_swizzle_shuffle(v[item], xor_mask, VirtualWaveSize); } } } template - ROCPRIM_DEVICE ROCPRIM_INLINE typename std::enable_if warp)>::type + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if warp)>::type swap(Key& k, bool dir, BinaryFunction compare_function) { (void)k; @@ -99,10 +108,11 @@ class warp_sort_shuffle } template - ROCPRIM_DEVICE ROCPRIM_INLINE typename std::enable_if<(WarpSize > warp)>::type + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if<(VirtualWaveSize > warp)>::type swap(Key& k, bool dir, BinaryFunction compare_function) { - Key k1 = warp_swizzle_shuffle(k, xor_mask, WarpSize); + Key k1 = warp_swizzle_shuffle(k, xor_mask, VirtualWaveSize); bool swap = compare_function(dir ? k : k1, dir ? k1 : k); if(swap) { @@ -111,7 +121,8 @@ class warp_sort_shuffle } template - ROCPRIM_DEVICE ROCPRIM_INLINE typename std::enable_if warp)>::type + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if warp)>::type swap(Key (&k)[ItemsPerThread], bool dir, BinaryFunction compare_function) { (void)k; @@ -120,14 +131,15 @@ class warp_sort_shuffle } template - ROCPRIM_DEVICE ROCPRIM_INLINE typename std::enable_if<(WarpSize > warp)>::type + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if<(VirtualWaveSize > warp)>::type swap(Key (&k)[ItemsPerThread], bool dir, BinaryFunction compare_function) { Key k1[ItemsPerThread]; ROCPRIM_UNROLL for(unsigned int item = 0; item < ItemsPerThread; item++) { - k1[item] = warp_swizzle_shuffle(k[item], xor_mask, WarpSize); + k1[item] = warp_swizzle_shuffle(k[item], xor_mask, VirtualWaveSize); bool swap = compare_function(dir ? k[item] : k1[item], dir ? k1[item] : k[item]); if(swap) { @@ -207,7 +219,8 @@ class warp_sort_shuffle } template - ROCPRIM_DEVICE ROCPRIM_INLINE typename std::enable_if<(WarpSize > warp)>::type + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if<(VirtualWaveSize > warp)>::type thread_merge(bool dir, BinaryFunction compare_function, KeyValue&... kv) { ROCPRIM_UNROLL @@ -218,7 +231,8 @@ class warp_sort_shuffle } template - ROCPRIM_DEVICE ROCPRIM_INLINE typename std::enable_if warp)>::type + ROCPRIM_DEVICE ROCPRIM_INLINE + typename std::enable_if warp)>::type thread_merge(bool /*dir*/, BinaryFunction /*compare_function*/, KeyValue&... /*kv*/) {} @@ -229,7 +243,7 @@ class warp_sort_shuffle static_assert(sizeof...(KeyValue) < 3, "KeyValue parameter pack can 1 or 2 elements (key, or key and value)"); - const unsigned int id = detail::logical_lane_id(); + const unsigned int id = detail::logical_lane_id(); swap<2, 1>(kv..., get_bit(id, 1) != get_bit(id, 0), compare_function); @@ -268,7 +282,7 @@ class warp_sort_shuffle static_assert(detail::is_power_of_two(ItemsPerThread), "ItemsPerThread must be power of 2"); - const unsigned int id = detail::logical_lane_id(); + const unsigned int id = detail::logical_lane_id(); thread_sort(get_bit(id, 0) != 0, compare_function, kv...); @@ -307,7 +321,7 @@ class warp_sort_shuffle } public: - static_assert(detail::is_power_of_two(WarpSize), "WarpSize must be power of 2"); + static_assert(detail::is_power_of_two(VirtualWaveSize), "VirtualWaveSize must be power of 2"); using storage_type = ::rocprim::detail::empty_storage_type; @@ -355,9 +369,9 @@ class warp_sort_shuffle sort(Key& thread_key, Value& thread_value, BinaryFunction compare_function) { // Instead of passing large values between lanes we pass indices and gather values after sorting. - unsigned int v = detail::logical_lane_id(); + unsigned int v = detail::logical_lane_id(); bitonic_sort(compare_function, thread_key, v); - thread_value = warp_shuffle(thread_value, v, WarpSize); + thread_value = warp_shuffle(thread_value, v, VirtualWaveSize); } template @@ -390,7 +404,7 @@ class warp_sort_shuffle ROCPRIM_UNROLL for(unsigned int item = 0; item < ItemsPerThread; item++) { - v[item] = ItemsPerThread * detail::logical_lane_id() + item; + v[item] = ItemsPerThread * detail::logical_lane_id() + item; } bitonic_sort(compare_function, thread_keys, v); @@ -408,7 +422,8 @@ class warp_sort_shuffle ROCPRIM_UNROLL for(unsigned src_item = 0; src_item < ItemsPerThread; ++src_item) { - V temp = warp_shuffle(copy[src_item], v[dst_item] / ItemsPerThread, WarpSize); + V temp + = warp_shuffle(copy[src_item], v[dst_item] / ItemsPerThread, VirtualWaveSize); if(v[dst_item] % ItemsPerThread == src_item) thread_values[dst_item] = temp; } diff --git a/rocprim/include/rocprim/warp/detail/warp_sort_stable.hpp b/rocprim/include/rocprim/warp/detail/warp_sort_stable.hpp index a720ea07f..8dd958821 100644 --- a/rocprim/include/rocprim/warp/detail/warp_sort_stable.hpp +++ b/rocprim/include/rocprim/warp/detail/warp_sort_stable.hpp @@ -39,7 +39,7 @@ namespace detail template class warp_sort_stable @@ -122,12 +122,12 @@ class warp_sort_stable const auto lane = lane_id(); const auto warp = warp_id(); - const auto warp_offset = warp * ItemsPerThread * arch::wavefront::min_size(); + const auto warp_offset = warp * ItemsPerThread * arch::wavefront::size(); const auto warp_input_size = warp_offset > input_size ? 0 : input_size - warp_offset; const auto shared_keys = &storage.keys[warp_offset]; ROCPRIM_UNROLL - for(auto partition_size = 1u; partition_size < WarpSize; partition_size <<= 1u) + for(auto partition_size = 1u; partition_size < VirtualWaveSize; partition_size <<= 1u) { ROCPRIM_UNROLL for(auto i = 0u; i < ItemsPerThread; ++i) @@ -181,13 +181,13 @@ class warp_sort_stable const auto lane = lane_id(); const auto warp = warp_id(); - const auto warp_offset = warp * ItemsPerThread * arch::wavefront::min_size(); + const auto warp_offset = warp * ItemsPerThread * arch::wavefront::size(); const auto warp_input_size = warp_offset > input_size ? 0 : input_size - warp_offset; const auto shared_keys = &storage.keys[warp_offset]; const auto shared_values = &storage.values[warp_offset]; ROCPRIM_UNROLL - for(auto partition_size = 1u; partition_size < WarpSize; partition_size <<= 1u) + for(auto partition_size = 1u; partition_size < VirtualWaveSize; partition_size <<= 1u) { ROCPRIM_UNROLL for(auto i = 0u; i < ItemsPerThread; ++i) @@ -237,7 +237,7 @@ class warp_sort_stable } public: - static_assert(detail::is_power_of_two(WarpSize), "WarpSize must be power of 2"); + static_assert(detail::is_power_of_two(VirtualWaveSize), "VirtualWaveSize must be power of 2"); ROCPRIM_DETAIL_SUPPRESS_DEPRECATION_WITH_PUSH using storage_type = raw_storage; @@ -350,8 +350,8 @@ class warp_sort_stable } }; -template -class warp_sort_stable +template +class warp_sort_stable { private: constexpr static unsigned items_per_thread = 1; @@ -422,7 +422,7 @@ class warp_sort_stable } public: - static_assert(detail::is_power_of_two(WarpSize), "WarpSize must be power of 2"); + static_assert(detail::is_power_of_two(VirtualWaveSize), "VirtualWaveSize must be power of 2"); using storage_type = empty_storage_type; @@ -430,7 +430,7 @@ class warp_sort_stable ROCPRIM_DEVICE ROCPRIM_INLINE void sort(Key& thread_key, BinaryFunction compare_function) { ROCPRIM_UNROLL - for(auto i = 1u; i < WarpSize; i <<= 1u) + for(auto i = 1u; i < VirtualWaveSize; i <<= 1u) { const auto thread_rank = merge_rank(i, thread_key, compare_function); thread_key = warp_permute(thread_key, thread_rank); @@ -477,11 +477,11 @@ class warp_sort_stable { (void)storage; - const auto warp_offset = warp_id() * arch::wavefront::min_size(); + const auto warp_offset = warp_id() * arch::wavefront::size(); const auto warp_input_size = warp_offset > input_size ? 0 : input_size - warp_offset; ROCPRIM_UNROLL - for(auto i = 1u; i < WarpSize; i <<= 1u) + for(auto i = 1u; i < VirtualWaveSize; i <<= 1u) { const auto thread_rank = merge_rank(i, thread_key, compare_function, warp_input_size); @@ -494,7 +494,7 @@ class warp_sort_stable sort(Key& thread_key, V& thread_value, BinaryFunction compare_function) { ROCPRIM_UNROLL - for(auto i = 1u; i < WarpSize; i <<= 1u) + for(auto i = 1u; i < VirtualWaveSize; i <<= 1u) { const auto thread_rank = merge_rank(i, thread_key, compare_function); thread_key = warp_permute(thread_key, thread_rank); @@ -562,11 +562,11 @@ class warp_sort_stable { (void)storage; - const auto warp_offset = warp_id() * arch::wavefront::min_size(); + const auto warp_offset = warp_id() * arch::wavefront::size(); const auto warp_input_size = warp_offset > input_size ? 0 : input_size - warp_offset; ROCPRIM_UNROLL - for(auto i = 1u; i < WarpSize; i <<= 1u) + for(auto i = 1u; i < VirtualWaveSize; i <<= 1u) { const auto thread_rank = merge_rank(i, thread_key, compare_function, warp_input_size); diff --git a/rocprim/include/rocprim/warp/warp_exchange.hpp b/rocprim/include/rocprim/warp/warp_exchange.hpp index 557cea70f..f04367533 100644 --- a/rocprim/include/rocprim/warp/warp_exchange.hpp +++ b/rocprim/include/rocprim/warp/warp_exchange.hpp @@ -45,7 +45,7 @@ BEGIN_ROCPRIM_NAMESPACE /// /// \tparam T the input type. /// \tparam ItemsPerThread the number of items contributed by each thread. -/// \tparam WarpSize the number of threads in a warp. +/// \tparam VirtualWaveSize the number of threads in a warp. /// /// \par Overview /// * The \p warp_exchange class supports the following rearrangement methods: @@ -80,40 +80,48 @@ BEGIN_ROCPRIM_NAMESPACE /// \endparblock template + unsigned int VirtualWaveSize = ::rocprim::arch::wavefront::min_size(), + ::rocprim::arch::wavefront::target TargetWaveSize + = ::rocprim::arch::wavefront::get_target()> class warp_exchange { - static_assert(::rocprim::detail::is_power_of_two(WarpSize), + static_assert(::rocprim::detail::is_power_of_two(VirtualWaveSize), "Logical warp size must be a power of two."); - ROCPRIM_DETAIL_DEVICE_STATIC_ASSERT( - WarpSize <= ::rocprim::arch::wavefront::min_size(), - "Logical warp size cannot be larger than physical warp size."); +public: + ROCPRIM_INLINE ROCPRIM_HOST_DEVICE warp_exchange() + { + detail::check_virtual_wave_size(); + } + +private: struct storage_type_ { - uninitialized_array buffer; + uninitialized_array buffer; }; template - ROCPRIM_DEVICE ROCPRIM_INLINE void Foreach(const T (&input)[ItemsPerThread], - U (&output)[ItemsPerThread], - const int xor_bit_set) + ROCPRIM_DEVICE ROCPRIM_INLINE + void Foreach(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread], + const int xor_bit_set) { // To prevent double work for IdX and IdX + NumEntries if(NumEntries != 0 && (IdX / NumEntries) % 2 == 0) { const T send_val = (xor_bit_set ? input[IdX] : input[IdX + NumEntries]); const T recv_val - = ::rocprim::detail::warp_swizzle_shuffle(send_val, NumEntries, WarpSize); + = ::rocprim::detail::warp_swizzle_shuffle(send_val, NumEntries, VirtualWaveSize); (xor_bit_set ? output[IdX] : output[IdX + NumEntries]) = recv_val; } } template - ROCPRIM_DEVICE ROCPRIM_INLINE void Foreach(const T (&input)[ItemsPerThread], - U (&output)[ItemsPerThread], - const std::integer_sequence, - const bool xor_bit_set) + ROCPRIM_DEVICE ROCPRIM_INLINE + void Foreach(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread], + const std::integer_sequence, + const bool xor_bit_set) { // To create a static inner loop that executes the code with // the values [0, 1, ..., ItemsPerThread-1, ItemsPerThread] as IdX @@ -122,10 +130,11 @@ class warp_exchange } template - ROCPRIM_DEVICE ROCPRIM_INLINE void TransposeImpl(const T (&input)[ItemsPerThread], - U (&output)[ItemsPerThread], - const unsigned int lane_id, - const std::integer_sequence) + ROCPRIM_DEVICE ROCPRIM_INLINE + void TransposeImpl(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread], + const unsigned int lane_id, + const std::integer_sequence) { // To create a static outer loop that executes the code with // the values [ItemsPerThread/2, ItemsPerThread/4, ..., 1, 0] as NumEntries @@ -139,9 +148,10 @@ class warp_exchange } template - ROCPRIM_DEVICE ROCPRIM_INLINE void Transpose(const T (&input)[ItemsPerThread], - U (&output)[ItemsPerThread], - const unsigned int lane_id) + ROCPRIM_DEVICE ROCPRIM_INLINE + void Transpose(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread], + const unsigned int lane_id) { constexpr unsigned int n_iter = rocprim::Log2::VALUE; TransposeImpl(input, @@ -183,21 +193,21 @@ class warp_exchange // Conditions for blocked to striped and striped to blocked struct conditions { - static constexpr bool is_equal_size = WarpSize == ItemsPerThread; + static constexpr bool is_equal_size = VirtualWaveSize == ItemsPerThread; static constexpr bool is_quad_compatible_bs = ItemsPerThread % ROCPRIM_QUAD_SIZE == 0 - && ItemsPerThread % (WarpSize / ROCPRIM_QUAD_SIZE) == 0; + && ItemsPerThread % (VirtualWaveSize / ROCPRIM_QUAD_SIZE) == 0; static constexpr bool is_quad_compatible_sb = ItemsPerThread % ROCPRIM_QUAD_SIZE == 0 - && ItemsPerThread % (WarpSize / ROCPRIM_QUAD_SIZE) == 0 + && ItemsPerThread % (VirtualWaveSize / ROCPRIM_QUAD_SIZE) == 0 // this config is not performant for the DPP quad_perm implementation - && !(WarpSize == 64 && ItemsPerThread == 32); + && !(VirtualWaveSize == 64 && ItemsPerThread == 32); - static constexpr bool warp_divide_items = ItemsPerThread % WarpSize == 0; + static constexpr bool warp_divide_items = ItemsPerThread % VirtualWaveSize == 0; - static constexpr bool items_divide_warp = WarpSize % ItemsPerThread == 0; + static constexpr bool items_divide_warp = VirtualWaveSize % ItemsPerThread == 0; }; enum class ImplementationType @@ -236,8 +246,8 @@ class warp_exchange template ROCPRIM_DEVICE ROCPRIM_INLINE void rearrange_items(unsigned int flat_id, - U (&input)[ItemsPerThread], - U (&output)[ItemsPerThread]) + U (&input)[ItemsPerThread], + U (&output)[ItemsPerThread]) { constexpr unsigned int ipt_div_width = ItemsPerThread / Width; @@ -266,11 +276,9 @@ class warp_exchange typename std::enable_if::value == ImplementationType::EqualSize>::type blocked_to_striped_shuffle_impl(const T (&input)[ItemsPerThread], - U (&output)[ItemsPerThread]) + U (&output)[ItemsPerThread]) { - static constexpr bool IS_ARCH_WARP = WarpSize == ::rocprim::arch::wavefront::min_size(); - const unsigned int flat_lane_id = ::rocprim::detail::logical_lane_id(); - const unsigned int lane_id = IS_ARCH_WARP ? flat_lane_id : (flat_lane_id % WarpSize); + const unsigned int lane_id = ::rocprim::detail::logical_lane_id(); T temp[ItemsPerThread]; ROCPRIM_UNROLL for(unsigned int i = 0; i < ItemsPerThread; i++) @@ -287,7 +295,7 @@ class warp_exchange // Case 2: Quad compatible // Works only when ItemsPerThread % ROCPRIM_QUAD_SIZE == 0 && - // ItemsPerThread % (WarpSize / ROCPRIM_QUAD_SIZE) == 0 + // ItemsPerThread % (VirtualWaveSize / ROCPRIM_QUAD_SIZE) == 0 // // FIRST PART: Going from blocked to striped at the quad level using DPP quad_perm and // item rearrangements @@ -327,12 +335,12 @@ class warp_exchange typename std::enable_if::value == ImplementationType::QuadCompatible>::type blocked_to_striped_shuffle_impl(const T (&input)[ItemsPerThread], - U (&output)[ItemsPerThread]) + U (&output)[ItemsPerThread]) { - constexpr unsigned int NUM_QUADS = WarpSize / ROCPRIM_QUAD_SIZE; + constexpr unsigned int NUM_QUADS = VirtualWaveSize / ROCPRIM_QUAD_SIZE; constexpr unsigned int IPT_DIV_QUAD_SIZE = ItemsPerThread / ROCPRIM_QUAD_SIZE; - const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); U values1[ItemsPerThread]; U values2[ItemsPerThread]; @@ -390,7 +398,7 @@ class warp_exchange quad_perm_index--; } - if ROCPRIM_IF_CONSTEXPR(WarpSize > ROCPRIM_QUAD_SIZE) + if constexpr(VirtualWaveSize > ROCPRIM_QUAD_SIZE) { // First warp rotation ROCPRIM_UNROLL @@ -400,7 +408,7 @@ class warp_exchange if(i_mod_quad_warp != 0) { const unsigned int total_rotation = i_mod_quad_warp * ROCPRIM_QUAD_SIZE; - values2[i] = warp_rotate_right(values2[i], total_rotation); + values2[i] = warp_rotate_right(values2[i], total_rotation); } } @@ -408,7 +416,7 @@ class warp_exchange // Second warp rotation constexpr unsigned int items_per_quad_warp = ItemsPerThread / NUM_QUADS; - unsigned int remaining_rotations = NUM_QUADS - 1; + unsigned int remaining_rotations = NUM_QUADS - 1; ROCPRIM_UNROLL for(unsigned int i = items_per_quad_warp; i < ItemsPerThread; i += items_per_quad_warp) { @@ -416,7 +424,7 @@ class warp_exchange for(unsigned int j = 0; j < items_per_quad_warp && (i + j) < ItemsPerThread; j++) { const unsigned int rotation = remaining_rotations * ROCPRIM_QUAD_SIZE; - values3[i + j] = warp_rotate_right(values3[i + j], rotation); + values3[i + j] = warp_rotate_right(values3[i + j], rotation); } remaining_rotations--; } @@ -436,12 +444,12 @@ class warp_exchange typename std::enable_if::value == ImplementationType::WarpDivideItems>::type blocked_to_striped_shuffle_impl(const T (&input)[ItemsPerThread], - U (&output)[ItemsPerThread]) + U (&output)[ItemsPerThread]) { - const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); - constexpr unsigned int ipt_div_warp = ItemsPerThread / WarpSize; - U values1[ItemsPerThread]; - U values2[ItemsPerThread]; + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + constexpr unsigned int ipt_div_warp = ItemsPerThread / VirtualWaveSize; + U values1[ItemsPerThread]; + U values2[ItemsPerThread]; ROCPRIM_UNROLL for(unsigned int i = 0; i < ItemsPerThread; i++) @@ -453,24 +461,24 @@ class warp_exchange ROCPRIM_UNROLL for(unsigned int i = 1; i < ItemsPerThread; i++) { - const unsigned int rotations = i % WarpSize; + const unsigned int rotations = i % VirtualWaveSize; if(rotations != 0) { - values1[i] = warp_rotate_right(values1[i], rotations); + values1[i] = warp_rotate_right(values1[i], rotations); } } - rearrange_items(flat_id, values1, values2); + rearrange_items(flat_id, values1, values2); // Second warp rotation - unsigned int rotations = WarpSize - 1; + unsigned int rotations = VirtualWaveSize - 1; ROCPRIM_UNROLL for(unsigned int i = 0; i < ItemsPerThread; i += ipt_div_warp) { ROCPRIM_UNROLL for(unsigned int j = 0; j < ipt_div_warp; j++) { - values2[i] = warp_rotate_right(values2[i], rotations); + values2[i] = warp_rotate_right(values2[i], rotations); rotations--; } } @@ -488,9 +496,9 @@ class warp_exchange typename std::enable_if::value == ImplementationType::ItemsDivideWarp>::type blocked_to_striped_shuffle_impl(const T (&input)[ItemsPerThread], - U (&output)[ItemsPerThread]) + U (&output)[ItemsPerThread]) { - const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); U work_array[ItemsPerThread]; ROCPRIM_UNROLL @@ -501,8 +509,8 @@ class warp_exchange { const auto value = ::rocprim::warp_shuffle( input[src_idx], - flat_id / ItemsPerThread + dst_idx * (WarpSize / ItemsPerThread), - WarpSize); + flat_id / ItemsPerThread + dst_idx * (VirtualWaveSize / ItemsPerThread), + VirtualWaveSize); if(src_idx == flat_id % ItemsPerThread) { work_array[dst_idx] = value; @@ -523,14 +531,14 @@ class warp_exchange typename std::enable_if::value == ImplementationType::EqualSize>::type striped_to_blocked_shuffle_impl(const T (&input)[ItemsPerThread], - U (&output)[ItemsPerThread]) + U (&output)[ItemsPerThread]) { blocked_to_striped_shuffle_impl(input, output); } // Case 2: Quad compatible // Works only when ItemsPerThread % ROCPRIM_QUAD_SIZE == 0 && - // ItemsPerThread % (WarpSize / ROCPRIM_QUAD_SIZE) == 0 + // ItemsPerThread % (VirtualWaveSize / ROCPRIM_QUAD_SIZE) == 0 // // The logic of this implementation is the inverse of blocked to striped. // Check comments of blocked to striped case 2 for more details. @@ -539,12 +547,12 @@ class warp_exchange typename std::enable_if::value == ImplementationType::QuadCompatible>::type striped_to_blocked_shuffle_impl(const T (&input)[ItemsPerThread], - U (&output)[ItemsPerThread]) + U (&output)[ItemsPerThread]) { - constexpr unsigned int NUM_QUADS = WarpSize / ROCPRIM_QUAD_SIZE; + constexpr unsigned int NUM_QUADS = VirtualWaveSize / ROCPRIM_QUAD_SIZE; constexpr unsigned int IPT_DIV_QUAD_SIZE = ItemsPerThread / ROCPRIM_QUAD_SIZE; - const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); U values1[ItemsPerThread]; U values2[ItemsPerThread]; @@ -557,11 +565,11 @@ class warp_exchange values2[i] = input[i]; } - if ROCPRIM_IF_CONSTEXPR(WarpSize > ROCPRIM_QUAD_SIZE) + if constexpr(VirtualWaveSize > ROCPRIM_QUAD_SIZE) { // First warp rotation constexpr unsigned int items_per_quad_warp = ItemsPerThread / NUM_QUADS; - unsigned int remaining_rotations = NUM_QUADS - 1; + unsigned int remaining_rotations = NUM_QUADS - 1; ROCPRIM_UNROLL for(unsigned int i = items_per_quad_warp; i < ItemsPerThread; i += items_per_quad_warp) { @@ -569,7 +577,7 @@ class warp_exchange for(unsigned int j = 0; j < items_per_quad_warp && (i + j) < ItemsPerThread; j++) { const unsigned int rotation = remaining_rotations * ROCPRIM_QUAD_SIZE; - values1[i + j] = warp_rotate_left(values1[i + j], rotation); + values1[i + j] = warp_rotate_left(values1[i + j], rotation); } remaining_rotations--; } @@ -584,7 +592,7 @@ class warp_exchange if(i_mod_quad_warp != 0) { const unsigned int total_rotation = i_mod_quad_warp * ROCPRIM_QUAD_SIZE; - values2[i] = warp_rotate_left(values2[i], total_rotation); + values2[i] = warp_rotate_left(values2[i], total_rotation); } } } @@ -649,12 +657,12 @@ class warp_exchange typename std::enable_if::value == ImplementationType::WarpDivideItems>::type striped_to_blocked_shuffle_impl(const T (&input)[ItemsPerThread], - U (&output)[ItemsPerThread]) + U (&output)[ItemsPerThread]) { - const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); - constexpr unsigned int ipt_div_warp = ItemsPerThread / WarpSize; - U values1[ItemsPerThread]; - U values2[ItemsPerThread]; + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + constexpr unsigned int ipt_div_warp = ItemsPerThread / VirtualWaveSize; + U values1[ItemsPerThread]; + U values2[ItemsPerThread]; ROCPRIM_UNROLL for(unsigned int i = 0; i < ItemsPerThread; i++) @@ -663,28 +671,28 @@ class warp_exchange } // First warp rotation - unsigned int rotations = WarpSize - 1; + unsigned int rotations = VirtualWaveSize - 1; ROCPRIM_UNROLL for(unsigned int i = 0; i < ItemsPerThread; i += ipt_div_warp) { ROCPRIM_UNROLL for(unsigned int j = 0; j < ipt_div_warp; j++) { - values1[i] = warp_rotate_left(values1[i], rotations); + values1[i] = warp_rotate_left(values1[i], rotations); rotations--; } } - rearrange_items(flat_id, values1, values2); + rearrange_items(flat_id, values1, values2); // Second warp rotation ROCPRIM_UNROLL for(unsigned int i = 1; i < ItemsPerThread; i++) { - const unsigned int rotations = i % WarpSize; + const unsigned int rotations = i % VirtualWaveSize; if(rotations != 0) { - values2[i] = warp_rotate_left(values2[i], rotations); + values2[i] = warp_rotate_left(values2[i], rotations); } } @@ -701,9 +709,9 @@ class warp_exchange typename std::enable_if::value == ImplementationType::ItemsDivideWarp>::type striped_to_blocked_shuffle_impl(const T (&input)[ItemsPerThread], - U (&output)[ItemsPerThread]) + U (&output)[ItemsPerThread]) { - const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); U work_array[ItemsPerThread]; ROCPRIM_UNROLL @@ -712,11 +720,11 @@ class warp_exchange ROCPRIM_UNROLL for(unsigned int src_idx = 0; src_idx < ItemsPerThread; src_idx++) { - const auto value - = ::rocprim::warp_shuffle(input[src_idx], - (ItemsPerThread * flat_id + dst_idx) % WarpSize, - WarpSize); - if(flat_id / (WarpSize / ItemsPerThread) == src_idx) + const auto value = ::rocprim::warp_shuffle(input[src_idx], + (ItemsPerThread * flat_id + dst_idx) + % VirtualWaveSize, + VirtualWaveSize); + if(flat_id / (VirtualWaveSize / ItemsPerThread) == src_idx) { work_array[dst_idx] = value; } @@ -781,7 +789,7 @@ class warp_exchange U (&output)[ItemsPerThread], storage_type& storage) { - const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); ROCPRIM_UNROLL for(unsigned int i = 0; i < ItemsPerThread; i++) { @@ -793,19 +801,19 @@ class warp_exchange ROCPRIM_UNROLL for(unsigned int i = 0; i < ItemsPerThread; i++) { - output[i] = storage_buffer[i * WarpSize + flat_id]; + output[i] = storage_buffer[i * VirtualWaveSize + flat_id]; } } /// \brief Transposes a blocked arrangement of items to a striped arrangement /// across the warp, using warp shuffle operations. - /// Uses an optimized implementation for when WarpSize is equal to ItemsPerThread. + /// Uses an optimized implementation for when VirtualWaveSize is equal to ItemsPerThread. /// Caution: this API is experimental. Performance might not be consistent. /// One of these following conditions must be satisfied: - /// 1. WarpSize is equal to ItemsPerThread - /// 2. ItemsPerThread % ROCPRIM_QUAD_SIZE == 0 && ItemsPerThread % (WarpSize / ROCPRIM_QUAD_SIZE) == 0 - /// 3. ItemsPerThread is divisible by WarpSize - /// 4. WarpSize is divisible by ItemsPerThread + /// 1. VirtualWaveSize is equal to ItemsPerThread + /// 2. ItemsPerThread % ROCPRIM_QUAD_SIZE == 0 && ItemsPerThread % (VirtualWaveSize / ROCPRIM_QUAD_SIZE) == 0 + /// 3. ItemsPerThread is divisible by VirtualWaveSize + /// 4. VirtualWaveSize is divisible by ItemsPerThread /// /// \tparam U [inferred] the output type. /// @@ -832,8 +840,8 @@ class warp_exchange /// } /// \endcode template - ROCPRIM_DEVICE ROCPRIM_INLINE void blocked_to_striped_shuffle(const T (&input)[ItemsPerThread], - U (&output)[ItemsPerThread]) + ROCPRIM_DEVICE ROCPRIM_INLINE + void blocked_to_striped_shuffle(const T (&input)[ItemsPerThread], U (&output)[ItemsPerThread]) { static_assert( conditions::is_equal_size || conditions::is_quad_compatible_bs @@ -883,12 +891,12 @@ class warp_exchange U (&output)[ItemsPerThread], storage_type& storage) { - const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); ROCPRIM_UNROLL for(unsigned int i = 0; i < ItemsPerThread; i++) { - storage.buffer.emplace(i * WarpSize + flat_id, input[i]); + storage.buffer.emplace(i * VirtualWaveSize + flat_id, input[i]); } ::rocprim::wave_barrier(); const auto& storage_buffer = storage.buffer.get_unsafe_array(); @@ -902,13 +910,13 @@ class warp_exchange /// \brief Transposes a striped arrangement of items to a blocked arrangement /// across the warp, using warp shuffle operations. - /// Uses an optimized implementation for when WarpSize is equal to ItemsPerThread. + /// Uses an optimized implementation for when VirtualWaveSize is equal to ItemsPerThread. /// Caution: this API is experimental. Performance might not be consistent. /// One of these following conditions must be satisfied: - /// 1. WarpSize is equal to ItemsPerThread - /// 2. ItemsPerThread % ROCPRIM_QUAD_SIZE == 0 && ItemsPerThread % (WarpSize / ROCPRIM_QUAD_SIZE) == 0 - /// 3. ItemsPerThread is divisible by WarpSize - /// 4. WarpSize is divisible by ItemsPerThread + /// 1. VirtualWaveSize is equal to ItemsPerThread + /// 2. ItemsPerThread % ROCPRIM_QUAD_SIZE == 0 && ItemsPerThread % (VirtualWaveSize / ROCPRIM_QUAD_SIZE) == 0 + /// 3. ItemsPerThread is divisible by VirtualWaveSize + /// 4. VirtualWaveSize is divisible by ItemsPerThread /// /// \tparam U [inferred] the output type. /// @@ -935,8 +943,8 @@ class warp_exchange /// } /// \endcode template - ROCPRIM_DEVICE ROCPRIM_INLINE void striped_to_blocked_shuffle(const T (&input)[ItemsPerThread], - U (&output)[ItemsPerThread]) + ROCPRIM_DEVICE ROCPRIM_INLINE + void striped_to_blocked_shuffle(const T (&input)[ItemsPerThread], U (&output)[ItemsPerThread]) { static_assert( conditions::is_equal_size || conditions::is_quad_compatible_sb @@ -948,7 +956,7 @@ class warp_exchange /// \brief Orders \p input values according to ranks using temporary storage, /// then writes the values to \p output in a striped manner. - /// No values in \p ranks should exists that exceed \p WarpSize*ItemsPerThread-1 . + /// No values in \p ranks should exists that exceed \p VirtualWaveSize*ItemsPerThread-1 . /// \tparam U [inferred] the output type. /// /// \param [in] input array that data is loaded from. @@ -987,16 +995,15 @@ class warp_exchange /// \endcode template ROCPRIM_DEVICE ROCPRIM_INLINE - void scatter_to_striped( - const T (&input)[ItemsPerThread], - U (&output)[ItemsPerThread], - const OffsetT (&ranks)[ItemsPerThread], - storage_type& storage) + void scatter_to_striped(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread], + const OffsetT (&ranks)[ItemsPerThread], + storage_type& storage) { - const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); ROCPRIM_UNROLL - for (unsigned int i = 0; i < ItemsPerThread; i++) + for(unsigned int i = 0; i < ItemsPerThread; i++) { storage.buffer.emplace(ranks[i], input[i]); } @@ -1004,14 +1011,74 @@ class warp_exchange const auto& storage_buffer = storage.buffer.get_unsafe_array(); ROCPRIM_UNROLL - for (unsigned int i = 0; i < ItemsPerThread; i++) + for(unsigned int i = 0; i < ItemsPerThread; i++) { - unsigned int item_offset = (i * WarpSize) + flat_id; + unsigned int item_offset = (i * VirtualWaveSize) + flat_id; output[i] = storage_buffer[item_offset]; } } }; +#ifndef DOXYGEN_SHOULD_SKIP_THIS + +template +class warp_exchange +{ +private: + using warp_exchange_wave32 = warp_exchange; + using warp_exchange_wave64 = warp_exchange; + using dispatch + = ::rocprim::detail::dispatch_wave_size; + +public: + using storage_type = typename dispatch::storage_type; + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto blocked_to_striped(Args&&... args) + { + dispatch{}([](auto impl, auto&&... args) { impl.blocked_to_striped(args...); }, args...); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto blocked_to_striped_shuffle(Args&&... args) + { + dispatch{}([](auto impl, auto&&... args) { impl.blocked_to_striped_shuffle(args...); }, + args...); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto striped_to_blocked(Args&&... args) + { + dispatch{}([](auto impl, auto&&... args) { impl.striped_to_blocked(args...); }, args...); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto striped_to_blocked_shuffle(Args&&... args) + { + dispatch{}([](auto impl, auto&&... args) { impl.striped_to_blocked_shuffle(args...); }, + args...); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto scatter_to_striped(Args&&... args) + { + dispatch{}([](auto impl, auto&&... args) { impl.scatter_to_striped(args...); }, args...); + } +}; + +#endif // DOXYGEN_SHOULD_SKIP_THIS + END_ROCPRIM_NAMESPACE /// @} diff --git a/rocprim/include/rocprim/warp/warp_load.hpp b/rocprim/include/rocprim/warp/warp_load.hpp index b2c904208..31456d3d9 100644 --- a/rocprim/include/rocprim/warp/warp_load.hpp +++ b/rocprim/include/rocprim/warp/warp_load.hpp @@ -21,12 +21,12 @@ #ifndef ROCPRIM_WARP_WARP_LOAD_HPP_ #define ROCPRIM_WARP_WARP_LOAD_HPP_ +#include "../block/block_load_func.hpp" #include "../config.hpp" -#include "../intrinsics.hpp" #include "../detail/various.hpp" +#include "../intrinsics/arch.hpp" #include "warp_exchange.hpp" -#include "../block/block_load_func.hpp" /// \addtogroup warpmodule /// @{ @@ -80,9 +80,11 @@ enum class warp_load_method /// \tparam T the input/output type. /// \tparam ItemsPerThread the number of items to be processed by /// each thread. -/// \tparam WarpSize the number of threads in the warp. It must be a divisor of the +/// \tparam VirtualWaveSize the number of threads in the warp. It must be a divisor of the /// kernel block size. /// \tparam Method the method to load data. +/// \tparam TargetWaveSize The hardware wavefront size. It can be used to specialize +/// the targeted wavefront size when compiling to SPIR-V. /// /// \par Overview /// * The \p warp_load class has a number of different methods to load data: @@ -115,20 +117,24 @@ enum class warp_load_method /// \endparblock template + unsigned int VirtualWaveSize = ::rocprim::arch::wavefront::min_size(), + warp_load_method Method = warp_load_method::warp_load_direct, + ::rocprim::arch::wavefront::target TargetWaveSize + = ::rocprim::arch::wavefront::get_target(), + typename Enabled = void> class warp_load { - static_assert(::rocprim::detail::is_power_of_two(WarpSize), + static_assert(::rocprim::detail::is_power_of_two(VirtualWaveSize), "Logical warp size must be a power of two."); - ROCPRIM_DETAIL_DEVICE_STATIC_ASSERT( - WarpSize <= ::rocprim::arch::wavefront::min_size(), - "Logical warp size cannot be larger than physical warp size."); private: using storage_type_ = typename ::rocprim::detail::empty_storage_type; public: + ROCPRIM_INLINE ROCPRIM_HOST_DEVICE warp_load() + { + detail::check_virtual_wave_size(); + } /// \brief Struct used to allocate a temporary memory that is required for thread /// communication during operations provided by related parallel primitive. /// @@ -166,7 +172,7 @@ class warp_load static_assert(std::is_convertible::value, "The type T must be such that an object of type InputIterator " "can be dereferenced and then implicitly converted to T."); - const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); block_load_direct_blocked(flat_id, input, items); } @@ -195,7 +201,7 @@ class warp_load static_assert(std::is_convertible::value, "The type T must be such that an object of type InputIterator " "can be dereferenced and then implicitly converted to T."); - const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); block_load_direct_blocked(flat_id, input, items, valid); } @@ -229,28 +235,68 @@ class warp_load static_assert(std::is_convertible::value, "The type T must be such that an object of type InputIterator " "can be dereferenced and then implicitly converted to T."); - const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); block_load_direct_blocked(flat_id, input, items, valid, out_of_bounds); } }; #ifndef DOXYGEN_SHOULD_SKIP_THIS +template +class warp_load +{ +private: + using warp_load_wave32 = warp_load; + using warp_load_wave64 = warp_load; + + using dispatch = detail::dispatch_wave_size; + +public: + using storage_type = typename dispatch::storage_type; + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto load(Args&&... args) + { + dispatch{}([](auto impl, auto&&... args) { impl.load(args...); }, args...); + } +}; -template< - class T, - unsigned int ItemsPerThread, - unsigned int WarpSize -> -class warp_load +template +class warp_load> { - static_assert(::rocprim::detail::is_power_of_two(WarpSize), + static_assert(::rocprim::detail::is_power_of_two(VirtualWaveSize), "Logical warp size must be a power of two."); - ROCPRIM_DETAIL_DEVICE_STATIC_ASSERT( - WarpSize <= ::rocprim::arch::wavefront::min_size(), - "Logical warp size cannot be larger than physical warp size."); public: + ROCPRIM_INLINE ROCPRIM_HOST_DEVICE warp_load() + { + detail::check_virtual_wave_size(); + } + using storage_type = typename ::rocprim::detail::empty_storage_type; template @@ -263,8 +309,8 @@ class warp_load::value, "The type T must be such that an object of type InputIterator " "can be dereferenced and then implicitly converted to T."); - const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); - block_load_direct_warp_striped(flat_id, input, items); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + block_load_direct_warp_striped(flat_id, input, items); } template @@ -278,8 +324,8 @@ class warp_load::value, "The type T must be such that an object of type InputIterator " "can be dereferenced and then implicitly converted to T."); - const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); - block_load_direct_warp_striped(flat_id, input, items, valid); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + block_load_direct_warp_striped(flat_id, input, items, valid); } template< @@ -297,26 +343,35 @@ class warp_load::value, "The type T must be such that an object of type InputIterator " "can be dereferenced and then implicitly converted to T."); - const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); - block_load_direct_warp_striped(flat_id, input, items, valid, - out_of_bounds); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + block_load_direct_warp_striped(flat_id, + input, + items, + valid, + out_of_bounds); } }; -template< - class T, - unsigned int ItemsPerThread, - unsigned int WarpSize -> -class warp_load +template +class warp_load> { - static_assert(::rocprim::detail::is_power_of_two(WarpSize), + static_assert(::rocprim::detail::is_power_of_two(VirtualWaveSize), "Logical warp size must be a power of two."); - ROCPRIM_DETAIL_DEVICE_STATIC_ASSERT( - WarpSize <= ::rocprim::arch::wavefront::min_size(), - "Logical warp size cannot be larger than physical warp size."); public: + ROCPRIM_INLINE ROCPRIM_HOST_DEVICE warp_load() + { + detail::check_virtual_wave_size(); + } + using storage_type = typename ::rocprim::detail::empty_storage_type; ROCPRIM_DEVICE ROCPRIM_INLINE @@ -324,7 +379,7 @@ class warp_load(); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); block_load_direct_blocked_vectorized(flat_id, input, items); } @@ -338,7 +393,7 @@ class warp_load::value, "The type T must be such that an object of type InputIterator " "can be dereferenced and then implicitly converted to T."); - const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); block_load_direct_blocked(flat_id, input, items); } @@ -353,7 +408,7 @@ class warp_load::value, "The type T must be such that an object of type InputIterator " "can be dereferenced and then implicitly converted to T."); - const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); block_load_direct_blocked(flat_id, input, items, valid); } @@ -372,27 +427,34 @@ class warp_load::value, "The type T must be such that an object of type InputIterator " "can be dereferenced and then implicitly converted to T."); - const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); block_load_direct_blocked(flat_id, input, items, valid, out_of_bounds); } }; -template< - class T, - unsigned int ItemsPerThread, - unsigned int WarpSize -> -class warp_load +template +class warp_load> { - static_assert(::rocprim::detail::is_power_of_two(WarpSize), + static_assert(::rocprim::detail::is_power_of_two(VirtualWaveSize), "Logical warp size must be a power of two."); - ROCPRIM_DETAIL_DEVICE_STATIC_ASSERT( - WarpSize <= ::rocprim::arch::wavefront::min_size(), - "Logical warp size cannot be larger than physical warp size."); + +public: + ROCPRIM_INLINE ROCPRIM_HOST_DEVICE warp_load() + { + detail::check_virtual_wave_size(); + } private: - using exchange_type = ::rocprim::warp_exchange; + using exchange_type = ::rocprim::warp_exchange; public: using storage_type = typename exchange_type::storage_type; @@ -407,8 +469,8 @@ class warp_load::value, "The type T must be such that an object of type InputIterator " "can be dereferenced and then implicitly converted to T."); - const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); - block_load_direct_warp_striped(flat_id, input, items); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + block_load_direct_warp_striped(flat_id, input, items); exchange_type().striped_to_blocked(items, items, storage); } @@ -423,8 +485,8 @@ class warp_load::value, "The type T must be such that an object of type InputIterator " "can be dereferenced and then implicitly converted to T."); - const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); - block_load_direct_warp_striped(flat_id, input, items, valid); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + block_load_direct_warp_striped(flat_id, input, items, valid); exchange_type().striped_to_blocked(items, items, storage); } @@ -443,9 +505,12 @@ class warp_load::value, "The type T must be such that an object of type InputIterator " "can be dereferenced and then implicitly converted to T."); - const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); - block_load_direct_warp_striped(flat_id, input, items, valid, - out_of_bounds); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + block_load_direct_warp_striped(flat_id, + input, + items, + valid, + out_of_bounds); exchange_type().striped_to_blocked(items, items, storage); } }; diff --git a/rocprim/include/rocprim/warp/warp_reduce.hpp b/rocprim/include/rocprim/warp/warp_reduce.hpp index e90d2885b..1f6766036 100644 --- a/rocprim/include/rocprim/warp/warp_reduce.hpp +++ b/rocprim/include/rocprim/warp/warp_reduce.hpp @@ -26,8 +26,8 @@ #include "../config.hpp" #include "../detail/various.hpp" -#include "../intrinsics.hpp" #include "../functional.hpp" +#include "../intrinsics.hpp" #include "../types.hpp" #include "detail/warp_reduce_crosslane.hpp" @@ -41,15 +41,15 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -// Select warp_reduce implementation based WarpSize -template +// Select warp_reduce implementation based VirtualWaveSize +template struct select_warp_reduce_impl { using type = typename std::conditional< // can we use crosslane (DPP or shuffle-based) implementation? - detail::is_warpsize_shuffleable::value, - detail::warp_reduce_crosslane, // yes - detail::warp_reduce_shared_mem // no + detail::is_warpsize_shuffleable::value, + detail::warp_reduce_crosslane, // yes + detail::warp_reduce_shared_mem // no >::type; }; @@ -60,24 +60,24 @@ struct select_warp_reduce_impl /// warp. /// /// \tparam T the input/output type. -/// \tparam WarpSize the size of logical warp size, which can be equal to or less than +/// \tparam VirtualWaveSize the size of logical warp size, which can be equal to or less than /// the size of hardware warp (see rocprim::arch::wavefront::min_size()). Reduce operations are performed -/// separately within groups determined by WarpSize. +/// separately within groups determined by VirtualWaveSize. /// \tparam UseAllReduce input parameter to determine whether to broadcast final reduction /// value to all threads (default is false). /// /// \par Overview -/// * \p WarpSize must be equal to or less than the size of hardware warp (see +/// * \p VirtualWaveSize must be equal to or less than the size of hardware warp (see /// rocprim::arch::wavefront::min_size()). If it is less, reduce is performed separately within groups -/// determined by WarpSize. \n -/// For example, if \p WarpSize is 4, hardware warp is 64, reduction will be performed in logical +/// determined by VirtualWaveSize. \n +/// For example, if \p VirtualWaveSize is 4, hardware warp is 64, reduction will be performed in logical /// warps grouped like this: `{ {0, 1, 2, 3}, {4, 5, 6, 7 }, ..., {60, 61, 62, 63} }` /// (thread is represented here by its id within hardware warp). -/// * Logical warp is a group of \p WarpSize consecutive threads from the same hardware warp. +/// * Logical warp is a group of \p VirtualWaveSize consecutive threads from the same hardware warp. /// * Supports non-commutative reduce operators. However, a reduce operator should be /// associative. When used with non-associative functions the results may be non-deterministic /// and/or vary in precision. -/// * Number of threads executing warp_reduce's function must be a multiple of \p WarpSize; +/// * Number of threads executing warp_reduce's function must be a multiple of \p VirtualWaveSize; /// * All threads from a logical warp must be in the same hardware warp. /// /// \par Examples @@ -106,17 +106,25 @@ struct select_warp_reduce_impl /// } /// \endcode /// \endparblock -template +template class warp_reduce #ifndef DOXYGEN_SHOULD_SKIP_THIS - : private detail::select_warp_reduce_impl::type + : private detail::select_warp_reduce_impl::type #endif { - using base_type = typename detail::select_warp_reduce_impl::type; + using base_type = + typename detail::select_warp_reduce_impl::type; + + // Check if VirtualWaveSize is valid for the targets - // Check if WarpSize is valid for the targets - static_assert(WarpSize <= ROCPRIM_MAX_WARP_SIZE, - "WarpSize can't be greater than hardware warp size."); +public: + ROCPRIM_INLINE ROCPRIM_HOST_DEVICE warp_reduce() + { + detail::check_virtual_wave_size(); + } public: /// \brief Struct used to allocate a temporary memory that is required for thread @@ -173,23 +181,35 @@ class warp_reduce /// } /// \endcode /// \endparblock - template, unsigned int FunctionWarpSize = WarpSize> + template, + unsigned int FunctionWarpSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto reduce(T input, T& output, storage_type& storage, BinaryFunction reduce_op = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize <= arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionWarpSize <= arch::wavefront::max_size()), void>::type { + if constexpr(TargetWaveSize == ::rocprim::arch::wavefront::target::dynamic) + { + if(VirtualWaveSize > ::rocprim::arch::wavefront::size()) + { + ROCPRIM_PRINT_ERROR_ONCE( + "Specified warp size exceeds current hardware supported warp " + "size. Aborting warp sort."); + return; + } + } base_type::reduce(input, output, storage, reduce_op); } /// \brief Performs reduction across threads in a logical warp. /// Invalid Warp Size - template, unsigned int FunctionWarpSize = WarpSize> + template, + unsigned int FunctionWarpSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto reduce(T, T&, storage_type&, BinaryFunction reduce_op = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize > arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionWarpSize > arch::wavefront::max_size()), void>::type { (void)reduce_op; ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp " @@ -243,24 +263,36 @@ class warp_reduce /// } /// \endcode /// \endparblock - template, unsigned int FunctionWarpSize = WarpSize> + template, + unsigned int FunctionWarpSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto reduce(T input, T& output, int valid_items, storage_type& storage, BinaryFunction reduce_op = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize <= arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionWarpSize <= arch::wavefront::max_size()), void>::type { + if constexpr(TargetWaveSize == ::rocprim::arch::wavefront::target::dynamic) + { + if(VirtualWaveSize > ::rocprim::arch::wavefront::size()) + { + ROCPRIM_PRINT_ERROR_ONCE( + "Specified warp size exceeds current hardware supported warp " + "size. Aborting warp sort."); + return; + } + } base_type::reduce(input, output, valid_items, storage, reduce_op); } /// \brief Performs reduction across threads in a logical warp. /// Invalid Warp Size - template, unsigned int FunctionWarpSize = WarpSize> + template, + unsigned int FunctionWarpSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto reduce(T, T&, int, storage_type&, BinaryFunction reduce_op = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize > arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionWarpSize > arch::wavefront::max_size()), void>::type { (void)reduce_op; ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp " @@ -288,15 +320,25 @@ class warp_reduce /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). template, - unsigned int FunctionWarpSize = WarpSize> + unsigned int FunctionWarpSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto head_segmented_reduce(T input, T& output, Flag flag, storage_type& storage, BinaryFunction reduce_op = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize <= arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionWarpSize <= arch::wavefront::max_size()), void>::type { + if constexpr(TargetWaveSize == ::rocprim::arch::wavefront::target::dynamic) + { + if(VirtualWaveSize > ::rocprim::arch::wavefront::size()) + { + ROCPRIM_PRINT_ERROR_ONCE( + "Specified warp size exceeds current hardware supported warp " + "size. Aborting warp sort."); + return; + } + } base_type::head_segmented_reduce(input, output, flag, storage, reduce_op); } @@ -304,11 +346,11 @@ class warp_reduce /// Invalid Warp Size template, - unsigned int FunctionWarpSize = WarpSize> + unsigned int FunctionWarpSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto head_segmented_reduce( T, T&, Flag, storage_type&, BinaryFunction reduce_op = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize > arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionWarpSize > arch::wavefront::max_size()), void>::type { (void)reduce_op; ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp " @@ -336,15 +378,25 @@ class warp_reduce /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). template, - unsigned int FunctionWarpSize = WarpSize> + unsigned int FunctionWarpSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto tail_segmented_reduce(T input, T& output, Flag flag, storage_type& storage, BinaryFunction reduce_op = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize <= arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionWarpSize <= arch::wavefront::max_size()), void>::type { + if constexpr(TargetWaveSize == ::rocprim::arch::wavefront::target::dynamic) + { + if(VirtualWaveSize > ::rocprim::arch::wavefront::size()) + { + ROCPRIM_PRINT_ERROR_ONCE( + "Specified warp size exceeds current hardware supported warp " + "size. Aborting warp sort."); + return; + } + } base_type::tail_segmented_reduce(input, output, flag, storage, reduce_op); } @@ -352,11 +404,11 @@ class warp_reduce /// Invalid Warp Size template, - unsigned int FunctionWarpSize = WarpSize> + unsigned int FunctionWarpSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto tail_segmented_reduce( T, T&, Flag, storage_type&, BinaryFunction reduce_op = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize > arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionWarpSize > arch::wavefront::max_size()), void>::type { (void)reduce_op; ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp " diff --git a/rocprim/include/rocprim/warp/warp_scan.hpp b/rocprim/include/rocprim/warp/warp_scan.hpp index 8248dd9b8..ed1023768 100644 --- a/rocprim/include/rocprim/warp/warp_scan.hpp +++ b/rocprim/include/rocprim/warp/warp_scan.hpp @@ -26,8 +26,8 @@ #include "../config.hpp" #include "../detail/various.hpp" -#include "../intrinsics.hpp" #include "../functional.hpp" +#include "../intrinsics.hpp" #include "../types.hpp" #include "detail/warp_scan_crosslane.hpp" @@ -41,15 +41,15 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -// Select warp_scan implementation based WarpSize -template +// Select warp_scan implementation based VirtualWaveSize +template struct select_warp_scan_impl { using type = typename std::conditional< // can we use crosslane (DPP or shuffle-based) implementation? - detail::is_warpsize_shuffleable::value, - detail::warp_scan_crosslane, // yes - detail::warp_scan_shared_mem // no + detail::is_warpsize_shuffleable::value, + detail::warp_scan_crosslane, // yes + detail::warp_scan_shared_mem // no >::type; }; @@ -60,22 +60,22 @@ struct select_warp_scan_impl /// threads in a hardware warp. /// /// \tparam T the input/output type. -/// \tparam WarpSize the size of logical warp size, which can be equal to or less than +/// \tparam VirtualWaveSize the size of logical warp size, which can be equal to or less than /// the size of hardware warp (see rocprim::arch::wavefront::min_size()). Scan operations are performed -/// separately within groups determined by WarpSize. +/// separately within groups determined by VirtualWaveSize. /// /// \par Overview -/// * \p WarpSize must be equal to or less than the size of hardware warp (see +/// * \p VirtualWaveSize must be equal to or less than the size of hardware warp (see /// rocprim::arch::wavefront::min_size()). If it is less, scan is performed separately within groups -/// determined by WarpSize. \n -/// For example, if \p WarpSize is 4, hardware warp is 64, scan will be performed in logical +/// determined by VirtualWaveSize. \n +/// For example, if \p VirtualWaveSize is 4, hardware warp is 64, scan will be performed in logical /// warps grouped like this: `{ {0, 1, 2, 3}, {4, 5, 6, 7 }, ..., {60, 61, 62, 63} }` /// (thread is represented here by its id within hardware warp). -/// * Logical warp is a group of \p WarpSize consecutive threads from the same hardware warp. +/// * Logical warp is a group of \p VirtualWaveSize consecutive threads from the same hardware warp. /// * Supports non-commutative scan operators. However, a scan operator should be /// associative. When used with non-associative functions the results may be non-deterministic /// and/or vary in precision. -/// * Number of threads executing warp_scan's function must be a multiple of \p WarpSize; +/// * Number of threads executing warp_scan's function must be a multiple of \p VirtualWaveSize; /// * All threads from a logical warp must be in the same hardware warp. /// /// \par Examples @@ -104,17 +104,19 @@ struct select_warp_scan_impl /// } /// \endcode /// \endparblock -template +template class warp_scan #ifndef DOXYGEN_SHOULD_SKIP_THIS - : private detail::select_warp_scan_impl::type + : private detail::select_warp_scan_impl::type #endif { - using base_type = typename detail::select_warp_scan_impl::type; + using base_type = typename detail::select_warp_scan_impl::type; - // Check if WarpSize is valid for the targets - static_assert(WarpSize <= ROCPRIM_MAX_WARP_SIZE, - "WarpSize can't be greater than hardware warp size."); + // Check if VirtualWaveSize is valid for the targets + static_assert(VirtualWaveSize <= ROCPRIM_MAX_WARP_SIZE, + "VirtualWaveSize can't be greater than hardware warp size."); public: /// \brief Struct used to allocate a temporary memory that is required for thread @@ -175,23 +177,36 @@ class warp_scan /// output values in the first logical warp will be {1, -2, -2, -4, ..., -32}, in the second: /// {33, -34, -34, -36, ..., -64} etc. /// \endparblock - template, unsigned int FunctionWarpSize = WarpSize> + template, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto inclusive_scan(T input, T& output, storage_type& storage, BinaryFunction scan_op = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize <= arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionVirtualWaveSize <= arch::wavefront::max_size()), + void>::type { + if constexpr(TargetWaveSize == ::rocprim::arch::wavefront::target::dynamic) + { + if(VirtualWaveSize > ::rocprim::arch::wavefront::size()) + { + ROCPRIM_PRINT_ERROR_ONCE( + "Specified warp size exceeds current hardware supported warp " + "size. Aborting warp sort."); + return; + } + } base_type::inclusive_scan(input, output, storage, scan_op); } /// \brief Performs inclusive scan across threads in a logical warp. /// Invalid Warp Size - template, unsigned int FunctionWarpSize = WarpSize> + template, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto inclusive_scan(T, T&, storage_type&, BinaryFunction scan_op = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize > arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionVirtualWaveSize > arch::wavefront::max_size()), void>::type { (void)scan_op; ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size " @@ -250,24 +265,36 @@ class warp_scan /// and the value for seeding the scan is -1, then output values in the first logical warp will be /// {-1, -2, -2, -4, ..., -32},, in the second: {-1, -34, -34, -36, ..., -64}, etc. /// \endparblock - template, unsigned int FunctionWarpSize = WarpSize> + template, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto inclusive_scan(T input, T& output, storage_type& storage, T init, BinaryFunction scan_op = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize <= arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionVirtualWaveSize <= arch::wavefront::max_size()), + void>::type { + if constexpr(TargetWaveSize == ::rocprim::arch::wavefront::target::dynamic) + { + if(VirtualWaveSize > ::rocprim::arch::wavefront::size()) + { + ROCPRIM_PRINT_ERROR_ONCE( + "Specified warp size exceeds current hardware supported warp " + "size. Aborting warp sort."); + } + } base_type::inclusive_scan(input, output, storage, scan_op, init); } /// \brief Performs seeded inclusive scan across threads in a logical warp. /// Invalid Warp Size - template, unsigned int FunctionWarpSize = WarpSize> + template, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto inclusive_scan(T, T&, storage_type&, T, BinaryFunction scan_op = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize > arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionVirtualWaveSize > arch::wavefront::max_size()), void>::type { (void)scan_op; ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size " @@ -324,24 +351,36 @@ class warp_scan /// \p output values in the every logical warp will be {1, 2, 3, 4, ..., 64}. /// The \p reduction will be equal \p 64. /// \endparblock - template, unsigned int FunctionWarpSize = WarpSize> + template, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto inclusive_scan(T input, T& output, T& reduction, storage_type& storage, BinaryFunction scan_op = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize <= arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionVirtualWaveSize <= arch::wavefront::max_size()), + void>::type { + if constexpr(TargetWaveSize == ::rocprim::arch::wavefront::target::dynamic) + { + if(VirtualWaveSize > ::rocprim::arch::wavefront::size()) + { + ROCPRIM_PRINT_ERROR_ONCE( + "Specified warp size exceeds current hardware supported warp " + "size. Aborting warp sort."); + } + } base_type::inclusive_scan(input, output, reduction, storage, scan_op); } /// \brief Performs inclusive scan and reduction across threads in a logical warp. /// Invalid Warp Size - template, unsigned int FunctionWarpSize = WarpSize> + template, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto inclusive_scan(T, T&, T&, storage_type&, BinaryFunction scan_op = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize > arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionVirtualWaveSize > arch::wavefront::max_size()), void>::type { (void)scan_op; ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size " @@ -357,6 +396,7 @@ class warp_scan /// \param [in] input thread input value. /// \param [out] output reference to a thread output value. May be aliased with \p input. /// \param [out] reduction result of reducing of all \p input values in logical warp. + /// This does not include \p init. /// \param [in] storage reference to a temporary storage object of type storage_type. /// \param [in] init initial value to seed the inclusive scan. /// \param [in] scan_op binary operation function object that will be used for scan. @@ -402,7 +442,8 @@ class warp_scan /// \p output values in the every logical warp will be {2, 3, 4, 5, ..., 65}. /// The \p reduction will be equal \p 65. /// \endparblock - template, unsigned int FunctionWarpSize = WarpSize> + template, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto inclusive_scan(T input, T& output, @@ -410,17 +451,28 @@ class warp_scan storage_type& storage, T init, BinaryFunction scan_op = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize <= arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionVirtualWaveSize <= arch::wavefront::max_size()), + void>::type { + if constexpr(TargetWaveSize == ::rocprim::arch::wavefront::target::dynamic) + { + if(VirtualWaveSize > ::rocprim::arch::wavefront::size()) + { + ROCPRIM_PRINT_ERROR_ONCE( + "Specified warp size exceeds current hardware supported warp " + "size. Aborting warp sort."); + } + } base_type::inclusive_scan(input, output, reduction, storage, scan_op, init); } /// \brief Performs seeded inclusive scan and reduction across threads in a logical warp. /// Invalid Warp Size - template, unsigned int FunctionWarpSize = WarpSize> + template, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto inclusive_scan(T, T&, T&, storage_type&, T, BinaryFunction scan_op = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize > arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionVirtualWaveSize > arch::wavefront::max_size()), void>::type { (void)scan_op; ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size " @@ -480,24 +532,36 @@ class warp_scan /// warp will be {100, 1, -2, -2, -4, ..., -30}, in the second: /// {100, 33, -34, -34, -36, ..., -62} etc. /// \endparblock - template, unsigned int FunctionWarpSize = WarpSize> + template, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto exclusive_scan(T input, T& output, T init, storage_type& storage, BinaryFunction scan_op = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize <= arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionVirtualWaveSize <= arch::wavefront::max_size()), + void>::type { + if constexpr(TargetWaveSize == ::rocprim::arch::wavefront::target::dynamic) + { + if(VirtualWaveSize > ::rocprim::arch::wavefront::size()) + { + ROCPRIM_PRINT_ERROR_ONCE( + "Specified warp size exceeds current hardware supported warp " + "size. Aborting warp sort."); + } + } base_type::exclusive_scan(input, output, init, storage, scan_op); } /// \brief Performs exclusive scan across threads in a logical warp. /// Invalid Warp Size - template, unsigned int FunctionWarpSize = WarpSize> + template, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto exclusive_scan(T, T&, T, storage_type&, BinaryFunction scan_op = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize > arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionVirtualWaveSize > arch::wavefront::max_size()), void>::type { (void)scan_op; ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size " @@ -558,7 +622,8 @@ class warp_scan /// {1, 1, ..., 1, 1}, then \p output values in every logical warp will be /// {10, 11, 12, 13, ..., 73}. The \p reduction will be 64. /// \endparblock - template, unsigned int FunctionWarpSize = WarpSize> + template, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto exclusive_scan(T input, T& output, @@ -566,17 +631,28 @@ class warp_scan T& reduction, storage_type& storage, BinaryFunction scan_op = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize <= arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionVirtualWaveSize <= arch::wavefront::max_size()), + void>::type { + if constexpr(TargetWaveSize == ::rocprim::arch::wavefront::target::dynamic) + { + if(VirtualWaveSize > ::rocprim::arch::wavefront::size()) + { + ROCPRIM_PRINT_ERROR_ONCE( + "Specified warp size exceeds current hardware supported warp " + "size. Aborting warp sort."); + } + } base_type::exclusive_scan(input, output, init, reduction, storage, scan_op); } /// \brief Performs exclusive scan and reduction across threads in a logical warp. /// Invalid Warp Size - template, unsigned int FunctionWarpSize = WarpSize> + template, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto exclusive_scan(T, T&, T, T&, storage_type&, BinaryFunction scan_op = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize > arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionVirtualWaveSize > arch::wavefront::max_size()), void>::type { (void)scan_op; ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size " @@ -592,14 +668,15 @@ class warp_scan /// thread of each logical warp. /// \param [in] storage Reference to a temporary storage object of type storage_type. /// \param scan_op The function object used to combine elements used for the scan - template, unsigned int FunctionWarpSize = WarpSize> + template, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto exclusive_scan(T input, T& output, storage_type& storage, BinaryFunction scan_op = BinaryFunction()) #ifndef DOXYGEN_DOCUMENTATION_BUILD - -> std::enable_if_t + -> std::enable_if_t #else -> void #endif @@ -608,13 +685,14 @@ class warp_scan } /// \cond - template, unsigned int FunctionWarpSize = WarpSize> + template, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto exclusive_scan(T /*input*/, T& /*output*/, storage_type& /*storage*/, BinaryFunction /*scan_op*/ = BinaryFunction()) - -> std::enable_if_t<(FunctionWarpSize > arch::wavefront::min_size())> + -> std::enable_if_t<(FunctionVirtualWaveSize > arch::wavefront::max_size())> { ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size." " Aborting warp scan."); @@ -631,7 +709,8 @@ class warp_scan /// \param[out] reduction Result of reducing of all `input` values in the logical warp. /// \param [in] storage Reference to a temporary storage object of type storage_type. /// \param scan_op The function object used to combine elements used for the scan - template, unsigned int FunctionWarpSize = WarpSize> + template, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto exclusive_scan(T input, T& output, @@ -639,7 +718,7 @@ class warp_scan T& reduction, BinaryFunction scan_op = BinaryFunction()) #ifndef DOXYGEN_DOCUMENTATION_BUILD - -> std::enable_if_t + -> std::enable_if_t #else -> void #endif @@ -648,14 +727,15 @@ class warp_scan } /// \cond - template, unsigned int FunctionWarpSize = WarpSize> + template, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto exclusive_scan(T /*input*/, T& /*output*/, storage_type& /*storage*/, T& /*reduction*/, BinaryFunction /*scan_op*/ = BinaryFunction()) - -> std::enable_if_t<(FunctionWarpSize > arch::wavefront::min_size())> + -> std::enable_if_t<(FunctionVirtualWaveSize > arch::wavefront::max_size())> { ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size." " Aborting warp scan."); @@ -720,7 +800,8 @@ class warp_scan /// logical warp will be {100, 1, -2, -2, -4, ..., -30}, in the second: /// {100, 33, -34, -34, -36, ..., -62} etc. /// \endparblock - template, unsigned int FunctionWarpSize = WarpSize> + template, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto scan(T input, T& inclusive_output, @@ -728,17 +809,28 @@ class warp_scan T init, storage_type& storage, BinaryFunction scan_op = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize <= arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionVirtualWaveSize <= arch::wavefront::max_size()), + void>::type { + if constexpr(TargetWaveSize == ::rocprim::arch::wavefront::target::dynamic) + { + if(VirtualWaveSize > ::rocprim::arch::wavefront::size()) + { + ROCPRIM_PRINT_ERROR_ONCE( + "Specified warp size exceeds current hardware supported warp " + "size. Aborting warp sort."); + } + } base_type::scan(input, inclusive_output, exclusive_output, init, storage, scan_op); } /// \brief Performs inclusive and exclusive scan operations across threads /// Invalid Warp Size - template, unsigned int FunctionWarpSize = WarpSize> + template, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto scan(T, T&, T&, T, storage_type&, BinaryFunction scan_op = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize > arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionVirtualWaveSize > arch::wavefront::max_size()), void>::type { (void)scan_op; ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size " @@ -804,7 +896,8 @@ class warp_scan /// {1, 2, 3, 4, ..., 63, 64}, and \p ex_output values in every logical warp will /// be {10, 11, 12, 13, ..., 73}. The \p reduction will be 64. /// \endparblock - template, unsigned int FunctionWarpSize = WarpSize> + template, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto scan(T input, T& inclusive_output, @@ -813,20 +906,34 @@ class warp_scan T& reduction, storage_type& storage, BinaryFunction scan_op = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize <= arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionVirtualWaveSize <= arch::wavefront::max_size()), + void>::type { - base_type::scan( - input, inclusive_output, exclusive_output, init, reduction, - storage, scan_op - ); + if constexpr(TargetWaveSize == ::rocprim::arch::wavefront::target::dynamic) + { + if(VirtualWaveSize > ::rocprim::arch::wavefront::size()) + { + ROCPRIM_PRINT_ERROR_ONCE( + "Specified warp size exceeds current hardware supported warp " + "size. Aborting warp sort."); + } + } + base_type::scan(input, + inclusive_output, + exclusive_output, + init, + reduction, + storage, + scan_op); } /// \brief Performs inclusive and exclusive scan operations across threads /// Invalid Warp Size - template, unsigned int FunctionWarpSize = WarpSize> + template, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto scan(T, T&, T&, T, T&, storage_type&, BinaryFunction scan_op = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize > arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionVirtualWaveSize > arch::wavefront::max_size()), void>::type { (void)scan_op; ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp size " @@ -843,47 +950,42 @@ class warp_scan /// \par Storage reusage /// Synchronization barrier should be placed before \p storage is reused /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). - template + template ROCPRIM_DEVICE ROCPRIM_INLINE auto broadcast(T input, const unsigned int src_lane, storage_type& storage) -> - typename std::enable_if<(FunctionWarpSize <= arch::wavefront::min_size()), T>::type + typename std::enable_if<(FunctionVirtualWaveSize <= arch::wavefront::max_size()), T>::type { + if constexpr(TargetWaveSize == ::rocprim::arch::wavefront::target::dynamic) + { + if(VirtualWaveSize > ::rocprim::arch::wavefront::size()) + { + ROCPRIM_PRINT_ERROR_ONCE( + "Specified warp size exceeds current hardware supported warp " + "size. Aborting warp sort."); + } + } return base_type::broadcast(input, src_lane, storage); } /// \brief Broadcasts value from one thread to all threads in logical warp. /// Invalid Warp Size - template + template ROCPRIM_DEVICE ROCPRIM_INLINE auto broadcast(T, const unsigned int, storage_type&) -> - typename std::enable_if<(FunctionWarpSize > arch::wavefront::min_size()), T>::type + typename std::enable_if<(FunctionVirtualWaveSize > arch::wavefront::max_size()), T>::type { ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp " "size. Aborting warp sort."); return T(); } -#ifndef DOXYGEN_SHOULD_SKIP_THIS -protected: - // These undocumented functions are used by hipCUB prior to version 3.1 - template - [[deprecated]] ROCPRIM_DEVICE ROCPRIM_INLINE - auto to_exclusive(T inclusive_input, T& exclusive_output, storage_type& storage) -> - typename std::enable_if<(FunctionWarpSize <= arch::wavefront::min_size()), void>::type - { - return base_type::to_exclusive(inclusive_input, exclusive_output, storage); - } - - template - [[deprecated]] ROCPRIM_DEVICE ROCPRIM_INLINE - auto to_exclusive(T, T&, storage_type&) -> - typename std::enable_if<(FunctionWarpSize > arch::wavefront::min_size()), void>::type - { - ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp " - "size. Aborting warp sort."); - return; - } -#endif + /// \brief Broadcasts value from one thread to all threads in logical warp. + /// Invalid Warp Size + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto broadcast(T, const unsigned int, storage_type&) -> + typename std::enable_if<(!detail::is_power_of_two(FunctionVirtualWaveSize)), T>::type + = delete; }; END_ROCPRIM_NAMESPACE diff --git a/rocprim/include/rocprim/warp/warp_sort.hpp b/rocprim/include/rocprim/warp/warp_sort.hpp index 1760ef710..50f10fdb0 100644 --- a/rocprim/include/rocprim/warp/warp_sort.hpp +++ b/rocprim/include/rocprim/warp/warp_sort.hpp @@ -41,19 +41,19 @@ BEGIN_ROCPRIM_NAMESPACE /// bitonic sort, and only accepts warp sizes that are powers of two. /// /// \tparam Key Data type for parameter Key -/// \tparam WarpSize [optional] The number of threads in a warp +/// \tparam VirtualWaveSize [optional] The number of threads in a warp /// \tparam Value [optional] Data type for parameter Value. By default, it's empty_type /// /// \par Overview -/// * \p WarpSize must be power of two. -/// * \p WarpSize must be equal to or less than the size of hardware warp (see -/// rocprim::arch::wavefront::min_size()). If it is less, sort is performed separately within groups -/// determined by WarpSize. -/// For example, if \p WarpSize is 4, hardware warp is 64, sort will be performed in logical +/// * \p VirtualWaveSize must be power of two. +/// * \p VirtualWaveSize must be equal to or less than the size of hardware warp (see +/// rocprim::arch::wavefront::max_size()). If it is less, sort is performed separately within groups +/// determined by VirtualWaveSize. +/// For example, if \p VirtualWaveSize is 4, hardware warp is 64, sort will be performed in logical /// warps grouped like this: `{ {0, 1, 2, 3}, {4, 5, 6, 7 }, ..., {60, 61, 62, 63} }` /// (thread is represented here by its id within hardware warp). /// * Accepts custom compare_functions for sorting across a warp. -/// * Number of threads executing warp_sort's function must be a multiple of \p WarpSize. +/// * Number of threads executing warp_sort's function must be a multiple of \p VirtualWaveSize. /// /// \par Stability /// \p warp_sort is not stable: it doesn't necessarily preserve the relative ordering @@ -100,14 +100,17 @@ BEGIN_ROCPRIM_NAMESPACE /// } /// \endcode /// \endparblock -template -class warp_sort : detail::warp_sort_shuffle +template +class warp_sort : detail::warp_sort_shuffle { - using base_type = typename detail::warp_sort_shuffle; + using base_type = typename detail::warp_sort_shuffle; - // Check if WarpSize is valid for the targets - static_assert(WarpSize <= ROCPRIM_MAX_WARP_SIZE, - "WarpSize can't be greater than hardware warp size."); + // Check if VirtualWaveSize is valid for the targets + static_assert(VirtualWaveSize <= ROCPRIM_MAX_WARP_SIZE, + "VirtualWaveSize can't be greater than hardware warp size."); public: /// \brief Struct used to allocate a temporary memory that is required for thread @@ -130,20 +133,33 @@ class warp_sort : detail::warp_sort_shuffle /// The signature of the function should be equivalent to the following: /// bool f(const T &a, const T &b);. The signature does not need to have /// const &, but function object must not modify the objects passed to it. - template, unsigned int FunctionWarpSize = WarpSize> + template, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto sort(Key& thread_key, BinaryFunction compare_function = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize <= arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionVirtualWaveSize <= arch::wavefront::max_size()), + void>::type { + if constexpr(TargetWaveSize == ::rocprim::arch::wavefront::target::dynamic) + { + if(VirtualWaveSize > ::rocprim::arch::wavefront::size()) + { + ROCPRIM_PRINT_ERROR_ONCE( + "Specified warp size exceeds current hardware supported warp " + "size. Aborting warp sort."); + return; + } + } base_type::sort(thread_key, compare_function); } /// \brief Warp sort for any data type. /// Invalid Warp Size - template, unsigned int FunctionWarpSize = WarpSize> + template, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto sort(Key&, BinaryFunction compare_function = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize > arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionVirtualWaveSize > arch::wavefront::max_size()), void>::type { (void)compare_function; // disables unused parameter warning ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp " @@ -162,25 +178,36 @@ class warp_sort : detail::warp_sort_shuffle /// bool f(const T &a, const T &b);. The signature does not need to have /// const &, but function object must not modify the objects passed to it. template, - unsigned int FunctionWarpSize = WarpSize> + class BinaryFunction = ::rocprim::less, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto sort(Key (&thread_keys)[ItemsPerThread], BinaryFunction compare_function = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize <= arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionVirtualWaveSize <= arch::wavefront::max_size()), + void>::type { + if constexpr(TargetWaveSize == ::rocprim::arch::wavefront::target::dynamic) + { + if(VirtualWaveSize > ::rocprim::arch::wavefront::size()) + { + ROCPRIM_PRINT_ERROR_ONCE( + "Specified warp size exceeds current hardware supported warp " + "size. Aborting warp sort."); + return; + } + } base_type::sort(thread_keys, compare_function); } /// \brief Warp sort for any data type. /// Invalid Warp Size template, - unsigned int FunctionWarpSize = WarpSize> + class BinaryFunction = ::rocprim::less, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto sort(Key (&thread_keys)[ItemsPerThread], BinaryFunction compare_function = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize > arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionVirtualWaveSize > arch::wavefront::max_size()), void>::type { (void)thread_keys; // disables unused parameter warning (void)compare_function; // disables unused parameter warning @@ -217,13 +244,25 @@ class warp_sort : detail::warp_sort_shuffle /// ... /// } /// \endcode - template, unsigned int FunctionWarpSize = WarpSize> + template, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto sort(Key& thread_key, storage_type& storage, BinaryFunction compare_function = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize <= arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionVirtualWaveSize <= arch::wavefront::max_size()), + void>::type { + if constexpr(TargetWaveSize == ::rocprim::arch::wavefront::target::dynamic) + { + if(VirtualWaveSize > ::rocprim::arch::wavefront::size()) + { + ROCPRIM_PRINT_ERROR_ONCE( + "Specified warp size exceeds current hardware supported warp " + "size. Aborting warp sort."); + return; + } + } base_type::sort( thread_key, storage, compare_function ); @@ -231,10 +270,11 @@ class warp_sort : detail::warp_sort_shuffle /// \brief Warp sort for any data type using temporary storage. /// Invalid Warp Size - template, unsigned int FunctionWarpSize = WarpSize> + template, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto sort(Key&, storage_type&, BinaryFunction compare_function = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize > arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionVirtualWaveSize > arch::wavefront::max_size()), void>::type { (void)compare_function; // disables unused parameter warning ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp " @@ -271,14 +311,25 @@ class warp_sort : detail::warp_sort_shuffle /// } /// \endcode template, - unsigned int FunctionWarpSize = WarpSize> + class BinaryFunction = ::rocprim::less, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto sort(Key (&thread_keys)[ItemsPerThread], storage_type& storage, BinaryFunction compare_function = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize <= arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionVirtualWaveSize <= arch::wavefront::max_size()), + void>::type { + if constexpr(TargetWaveSize == ::rocprim::arch::wavefront::target::dynamic) + { + if(VirtualWaveSize > ::rocprim::arch::wavefront::size()) + { + ROCPRIM_PRINT_ERROR_ONCE( + "Specified warp size exceeds current hardware supported warp " + "size. Aborting warp sort."); + return; + } + } base_type::sort( thread_keys, storage, compare_function ); @@ -287,13 +338,13 @@ class warp_sort : detail::warp_sort_shuffle /// \brief Warp sort for any data type using temporary storage. /// Invalid Warp Size template, - unsigned int FunctionWarpSize = WarpSize> + class BinaryFunction = ::rocprim::less, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto sort(Key (&thread_keys)[ItemsPerThread], storage_type&, BinaryFunction compare_function = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize > arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionVirtualWaveSize > arch::wavefront::max_size()), void>::type { (void)thread_keys; // disables unused parameter warning (void)compare_function; // disables unused parameter warning @@ -313,13 +364,25 @@ class warp_sort : detail::warp_sort_shuffle /// The signature of the function should be equivalent to the following: /// bool f(const T &a, const T &b);. The signature does not need to have /// const &, but function object must not modify the objects passed to it. - template, unsigned int FunctionWarpSize = WarpSize> + template, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto sort(Key& thread_key, Value& thread_value, BinaryFunction compare_function = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize <= arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionVirtualWaveSize <= arch::wavefront::max_size()), + void>::type { + if constexpr(TargetWaveSize == ::rocprim::arch::wavefront::target::dynamic) + { + if(VirtualWaveSize > ::rocprim::arch::wavefront::size()) + { + ROCPRIM_PRINT_ERROR_ONCE( + "Specified warp size exceeds current hardware supported warp " + "size. Aborting warp sort."); + return; + } + } base_type::sort( thread_key, thread_value, compare_function ); @@ -327,10 +390,11 @@ class warp_sort : detail::warp_sort_shuffle /// \brief Warp sort by key for any data type. /// Invalid Warp Size - template, unsigned int FunctionWarpSize = WarpSize> + template, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto sort(Key&, Value&, BinaryFunction compare_function = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize > arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionVirtualWaveSize > arch::wavefront::max_size()), void>::type { (void)compare_function; // disables unused parameter warning ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp " @@ -350,14 +414,25 @@ class warp_sort : detail::warp_sort_shuffle /// bool f(const T &a, const T &b);. The signature does not need to have /// const &, but function object must not modify the objects passed to it. template, - unsigned int FunctionWarpSize = WarpSize> + class BinaryFunction = ::rocprim::less, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto sort(Key (&thread_keys)[ItemsPerThread], Value (&thread_values)[ItemsPerThread], BinaryFunction compare_function = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize <= arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionVirtualWaveSize <= arch::wavefront::max_size()), + void>::type { + if constexpr(TargetWaveSize == ::rocprim::arch::wavefront::target::dynamic) + { + if(VirtualWaveSize > ::rocprim::arch::wavefront::size()) + { + ROCPRIM_PRINT_ERROR_ONCE( + "Specified warp size exceeds current hardware supported warp " + "size. Aborting warp sort."); + return; + } + } base_type::sort( thread_keys, thread_values, compare_function ); @@ -366,13 +441,13 @@ class warp_sort : detail::warp_sort_shuffle /// \brief Warp sort by key for any data type. /// Invalid Warp Size template, - unsigned int FunctionWarpSize = WarpSize> + class BinaryFunction = ::rocprim::less, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto sort(Key (&thread_keys)[ItemsPerThread], Value (&thread_values)[ItemsPerThread], BinaryFunction compare_function = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize > arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionVirtualWaveSize > arch::wavefront::max_size()), void>::type { (void)thread_keys; // disables unused parameter warning (void)thread_values; // disables unused parameter warning @@ -411,14 +486,26 @@ class warp_sort : detail::warp_sort_shuffle /// ... /// } /// \endcode - template, unsigned int FunctionWarpSize = WarpSize> + template, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto sort(Key& thread_key, Value& thread_value, storage_type& storage, BinaryFunction compare_function = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize <= arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionVirtualWaveSize <= arch::wavefront::max_size()), + void>::type { + if constexpr(TargetWaveSize == ::rocprim::arch::wavefront::target::dynamic) + { + if(VirtualWaveSize > ::rocprim::arch::wavefront::size()) + { + ROCPRIM_PRINT_ERROR_ONCE( + "Specified warp size exceeds current hardware supported warp " + "size. Aborting warp sort."); + return; + } + } base_type::sort( thread_key, thread_value, storage, compare_function ); @@ -426,10 +513,11 @@ class warp_sort : detail::warp_sort_shuffle /// \brief Warp sort by key for any data type using temporary storage. /// Invalid Warp Size - template, unsigned int FunctionWarpSize = WarpSize> + template, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto sort(Key&, Value&, storage_type&, BinaryFunction compare_function = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize > arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionVirtualWaveSize > arch::wavefront::max_size()), void>::type { (void)compare_function; // disables unused parameter warning ROCPRIM_PRINT_ERROR_ONCE("Specified warp size exceeds current hardware supported warp " @@ -467,15 +555,26 @@ class warp_sort : detail::warp_sort_shuffle /// } /// \endcode template, - unsigned int FunctionWarpSize = WarpSize> + class BinaryFunction = ::rocprim::less, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto sort(Key (&thread_keys)[ItemsPerThread], Value (&thread_values)[ItemsPerThread], storage_type& storage, BinaryFunction compare_function = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize <= arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionVirtualWaveSize <= arch::wavefront::max_size()), + void>::type { + if constexpr(TargetWaveSize == ::rocprim::arch::wavefront::target::dynamic) + { + if(VirtualWaveSize > ::rocprim::arch::wavefront::size()) + { + ROCPRIM_PRINT_ERROR_ONCE( + "Specified warp size exceeds current hardware supported warp " + "size. Aborting warp sort."); + return; + } + } base_type::sort( thread_keys, thread_values, storage, compare_function ); @@ -484,14 +583,14 @@ class warp_sort : detail::warp_sort_shuffle /// \brief Warp sort by key for any data type using temporary storage. /// Invalid Warp Size template, - unsigned int FunctionWarpSize = WarpSize> + class BinaryFunction = ::rocprim::less, + unsigned int FunctionVirtualWaveSize = VirtualWaveSize> ROCPRIM_DEVICE ROCPRIM_INLINE auto sort(Key (&thread_keys)[ItemsPerThread], Value (&thread_values)[ItemsPerThread], storage_type&, BinaryFunction compare_function = BinaryFunction()) -> - typename std::enable_if<(FunctionWarpSize > arch::wavefront::min_size()), void>::type + typename std::enable_if<(FunctionVirtualWaveSize > arch::wavefront::max_size()), void>::type { (void)thread_keys; // disables unused parameter warning (void)thread_values; // disables unused parameter warning diff --git a/rocprim/include/rocprim/warp/warp_store.hpp b/rocprim/include/rocprim/warp/warp_store.hpp index e52381759..e4ef14630 100644 --- a/rocprim/include/rocprim/warp/warp_store.hpp +++ b/rocprim/include/rocprim/warp/warp_store.hpp @@ -82,7 +82,7 @@ enum class warp_store_method /// \tparam T the output/output type. /// \tparam ItemsPerThread the number of items to be processed by /// each thread. -/// \tparam WarpSize the number of threads in a warp. It must be a divisor of the +/// \tparam VirtualWaveSize the number of threads in a warp. It must be a divisor of the /// kernel block size. /// \tparam Method the method to store data. /// @@ -117,20 +117,25 @@ enum class warp_store_method /// \endparblock template + unsigned int VirtualWaveSize = ::rocprim::arch::wavefront::min_size(), + warp_store_method Method = warp_store_method::warp_store_direct, + ::rocprim::arch::wavefront::target TargetWaveSize + = ::rocprim::arch::wavefront::get_target(), + typename Enabled = void> class warp_store { - static_assert(::rocprim::detail::is_power_of_two(WarpSize), + static_assert(::rocprim::detail::is_power_of_two(VirtualWaveSize), "Logical warp size must be a power of two."); - ROCPRIM_DETAIL_DEVICE_STATIC_ASSERT( - WarpSize <= ::rocprim::arch::wavefront::min_size(), - "Logical warp size cannot be larger than physical warp size."); private: using storage_type_ = typename ::rocprim::detail::empty_storage_type; public: + ROCPRIM_INLINE ROCPRIM_HOST_DEVICE warp_store() + { + detail::check_virtual_wave_size(); + } + /// \brief Struct used to allocate a temporary memory that is required for thread /// communication during operations provided by related parallel primitive. /// @@ -171,7 +176,7 @@ class warp_store static_assert(std::is_convertible::value, "The type T must be such that an object of type OutputIterator " "can be dereferenced and then implicitly assigned from T."); - const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); block_store_direct_blocked(flat_id, output, items); } @@ -204,27 +209,67 @@ class warp_store static_assert(std::is_convertible::value, "The type T must be such that an object of type OutputIterator " "can be dereferenced and then implicitly assigned from T."); - const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); block_store_direct_blocked(flat_id, output, items, valid); } }; #ifndef DOXYGEN_SHOULD_SKIP_THIS -template< - class T, - unsigned int ItemsPerThread, - unsigned int WarpSize -> -class warp_store +template +class warp_store +{ +private: + using warp_store_wave32 = warp_store; + using warp_store_wave64 = warp_store; + using dispatch = ::rocprim::detail::dispatch_wave_size; + +public: + using storage_type = typename dispatch::storage_type; + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + auto store(Args&&... args) + { + dispatch{}([](auto impl, auto&&... args) { impl.store(args...); }, args...); + } +}; + +template +class warp_store> { - static_assert(::rocprim::detail::is_power_of_two(WarpSize), + static_assert(::rocprim::detail::is_power_of_two(VirtualWaveSize), "Logical warp size must be a power of two."); - ROCPRIM_DETAIL_DEVICE_STATIC_ASSERT( - WarpSize <= ::rocprim::arch::wavefront::min_size(), - "Logical warp size cannot be larger than physical warp size."); public: + ROCPRIM_INLINE ROCPRIM_HOST_DEVICE warp_store() + { + detail::check_virtual_wave_size(); + } + using storage_type = typename ::rocprim::detail::empty_storage_type; template @@ -237,8 +282,8 @@ class warp_store::value, "The type T must be such that an object of type OutputIterator " "can be dereferenced and then implicitly assigned from T."); - const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); - block_store_direct_warp_striped(flat_id, output, items); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + block_store_direct_warp_striped(flat_id, output, items); } template @@ -252,25 +297,31 @@ class warp_store::value, "The type T must be such that an object of type OutputIterator " "can be dereferenced and then implicitly assigned from T."); - const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); - block_store_direct_warp_striped(flat_id, output, items, valid); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + block_store_direct_warp_striped(flat_id, output, items, valid); } }; -template< - class T, - unsigned int ItemsPerThread, - unsigned int WarpSize -> -class warp_store +template +class warp_store> { - static_assert(::rocprim::detail::is_power_of_two(WarpSize), + static_assert(::rocprim::detail::is_power_of_two(VirtualWaveSize), "Logical warp size must be a power of two."); - ROCPRIM_DETAIL_DEVICE_STATIC_ASSERT( - WarpSize <= ::rocprim::arch::wavefront::min_size(), - "Logical warp size cannot be larger than physical warp size."); public: + ROCPRIM_INLINE ROCPRIM_HOST_DEVICE warp_store() + { + detail::check_virtual_wave_size(); + } + using storage_type = typename ::rocprim::detail::empty_storage_type; ROCPRIM_DEVICE ROCPRIM_INLINE @@ -278,7 +329,7 @@ class warp_store(); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); block_store_direct_blocked_vectorized(flat_id, output, items); } @@ -292,7 +343,7 @@ class warp_store::value, "The type T must be such that an object of type OutputIterator " "can be dereferenced and then implicitly assigned from T."); - const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); block_store_direct_blocked(flat_id, output, items); } @@ -307,28 +358,34 @@ class warp_store::value, "The type T must be such that an object of type OutputIterator " "can be dereferenced and then implicitly assigned from T."); - const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); block_store_direct_blocked(flat_id, output, items, valid); } }; -template< - class T, - unsigned int ItemsPerThread, - unsigned int WarpSize -> -class warp_store +template +class warp_store> { - static_assert(::rocprim::detail::is_power_of_two(WarpSize), + static_assert(::rocprim::detail::is_power_of_two(VirtualWaveSize), "Logical warp size must be a power of two."); - ROCPRIM_DETAIL_DEVICE_STATIC_ASSERT( - WarpSize <= ::rocprim::arch::wavefront::min_size(), - "Logical warp size cannot be larger than physical warp size."); private: - using exchange_type = ::rocprim::warp_exchange; + using exchange_type = ::rocprim::warp_exchange; public: + ROCPRIM_INLINE ROCPRIM_HOST_DEVICE warp_store() + { + detail::check_virtual_wave_size(); + } + using storage_type = typename exchange_type::storage_type; template @@ -342,8 +399,8 @@ class warp_store(); - block_store_direct_warp_striped(flat_id, output, items); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + block_store_direct_warp_striped(flat_id, output, items); } template @@ -358,8 +415,8 @@ class warp_store(); - block_store_direct_warp_striped(flat_id, output, items, valid); + const unsigned int flat_id = ::rocprim::detail::logical_lane_id(); + block_store_direct_warp_striped(flat_id, output, items, valid); } }; diff --git a/scripts/autotune-search/main.py b/scripts/autotune-search/main.py index b65000644..f9e7fb256 100755 --- a/scripts/autotune-search/main.py +++ b/scripts/autotune-search/main.py @@ -52,7 +52,7 @@ ], }, "params": { - "LongBits": [6, 7, 8], + "RadixBits": [6, 7, 8], "BlockSize": [256], "ItemsPerThread": [7, 8, 13, 16, 17], "WarpSmallLWS": [8, 16, 32, 64], @@ -84,7 +84,7 @@ ], }, "params": { - "LongBits": [6, 7, 8], + "RadixBits": [6, 7, 8], "BlockSize": [256], "ItemsPerThread": [7, 8, 13, 16, 17], "WarpSmallLWS": [8, 16, 32, 64], diff --git a/scripts/autotune/create_optimization.py b/scripts/autotune/create_optimization.py index 1042bfb6d..853eea775 100755 --- a/scripts/autotune/create_optimization.py +++ b/scripts/autotune/create_optimization.py @@ -556,6 +556,14 @@ class AlgorithmDeviceTransform(Algorithm): def __init__(self, fallback_entries): Algorithm.__init__(self, fallback_entries) +class AlgorithmDeviceTransformPointer(Algorithm): + algorithm_name = "device_transform_pointer" + cpp_configuration_template_name = "transform_pointer_config_template" + config_selection_params = [ + SelectionType(name="value_type", is_optional=False, select_on_size_only=False)] + def __init__(self, fallback_entries): + Algorithm.__init__(self, fallback_entries) + class AlgorithmDevicePartitionTwoWayPredicate(Algorithm): algorithm_name = "device_partition_two_way_predicate" cpp_configuration_template_name = "partition_two_way_predicate_config_template" @@ -596,6 +604,14 @@ class AlgorithmDevicePartitionThreeWay(Algorithm): def __init__(self, fallback_entries): Algorithm.__init__(self, fallback_entries) +class AlgorithmDeviceSearchN(Algorithm): + algorithm_name = "device_search_n" + cpp_configuration_template_name = "search_n_config_template" + config_selection_params = [ + SelectionType(name="data_type", is_optional=False, select_on_size_only=False)] + def __init__(self, fallback_entries): + Algorithm.__init__(self, fallback_entries) + class AlgorithmDeviceSelectFlag(Algorithm): algorithm_name = "device_select_flag" cpp_configuration_template_name = "select_flag_config_template" @@ -717,6 +733,8 @@ def create_algorithm(algorithm_name: str, fallback_entries: List[FallbackCase]): return AlgorithmDeviceSegmentedRadixSort(fallback_entries) elif algorithm_name == 'device_transform': return AlgorithmDeviceTransform(fallback_entries) + elif algorithm_name == 'device_transform_pointer': + return AlgorithmDeviceTransformPointer(fallback_entries) elif algorithm_name == 'device_partition_two_way_predicate': return AlgorithmDevicePartitionTwoWayPredicate(fallback_entries) elif algorithm_name == 'device_partition_two_way_flag': @@ -727,6 +745,8 @@ def create_algorithm(algorithm_name: str, fallback_entries: List[FallbackCase]): return AlgorithmDevicePartitionPredicate(fallback_entries) elif algorithm_name == 'device_partition_three_way': return AlgorithmDevicePartitionThreeWay(fallback_entries) + elif algorithm_name == 'device_search_n': + return AlgorithmDeviceSearchN(fallback_entries) elif algorithm_name == 'device_select_flag': return AlgorithmDeviceSelectFlag(fallback_entries) elif algorithm_name == 'device_select_predicate': diff --git a/scripts/autotune/templates/config_template b/scripts/autotune/templates/config_template index 11443aae7..9d44e5cdc 100644 --- a/scripts/autotune/templates/config_template +++ b/scripts/autotune/templates/config_template @@ -22,7 +22,7 @@ #define {{ get_header_guard() }} #include "../../../config.hpp" -#include "../../../type_traits_interface.hpp" +#include "../../../type_traits.hpp" #include "../../config_types.hpp" #include "../device_config_helper.hpp" diff --git a/scripts/autotune/templates/search_n_config_template b/scripts/autotune/templates/search_n_config_template new file mode 100644 index 000000000..7ad5393b1 --- /dev/null +++ b/scripts/autotune/templates/search_n_config_template @@ -0,0 +1,19 @@ +{% extends "config_template" %} + +{% macro get_header_guard() %} +ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SEARCH_N_HPP_ +{%- endmacro %} + +{% macro kernel_configuration(measurement) -%} +search_n_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}, {{ measurement['cfg']['threshold'] }}> { }; +{%- endmacro %} + +{% macro general_case() -%} +template struct default_search_n_config : +default_search_n_config_base::type { }; +{%- endmacro %} + +{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} +// Based on {{ based_on_type }} +template struct default_search_n_config({{ benchmark_of_architecture.name }}), data_type, {{ fallback_selection_criteria }}> : +{%- endmacro %} diff --git a/scripts/autotune/templates/segmented_radix_sort_config_template b/scripts/autotune/templates/segmented_radix_sort_config_template index 6e6ef806b..4423befa4 100644 --- a/scripts/autotune/templates/segmented_radix_sort_config_template +++ b/scripts/autotune/templates/segmented_radix_sort_config_template @@ -6,7 +6,7 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SEGMENTED_RADIX_SORT_HPP_ {% macro kernel_configuration(measurement) -%} segmented_radix_sort_config< - {{ measurement['cfg']['lrb'] }}, {{ measurement['cfg']['srb'] }}, + {{ measurement['cfg']['rb'] }}, kernel_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}>, typename std::conditional< {{ measurement['cfg']['wsc']['pa'] }}, diff --git a/scripts/autotune/templates/transform_pointer_config_template b/scripts/autotune/templates/transform_pointer_config_template new file mode 100644 index 000000000..efa1193df --- /dev/null +++ b/scripts/autotune/templates/transform_pointer_config_template @@ -0,0 +1,20 @@ +{% extends "config_template" %} + +{% macro get_header_guard() %} +ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_TRANSFORM_POINTER_HPP_ +{%- endmacro %} + +{% macro kernel_configuration(measurement) -%} +transform_pointer_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::{{ measurement['cfg']['lt'] }}> { }; +{%- endmacro %} + +{% macro general_case() -%} +template +struct default_transform_pointer_config : default_transform_pointer_config_base::type +{}; +{%- endmacro %} + +{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} +// Based on {{ based_on_type }} +template struct default_transform_pointer_config({{ benchmark_of_architecture.name }}), value_type, {{ fallback_selection_criteria }}> : +{%- endmacro %} diff --git a/test/common_test_header.hpp b/test/common_test_header.hpp index 02a8d17ef..fe845c4e2 100755 --- a/test/common_test_header.hpp +++ b/test/common_test_header.hpp @@ -81,9 +81,6 @@ #define INSTANTIATE_TYPED_TEST(test_suite_name, ...) \ INSTANTIATE_TYPED_TEST_EXPANDED(__LINE__, test_suite_name, __VA_ARGS__) -// C++17 or newer -#define CPP17 __cplusplus >= 201703L - #include #include #include diff --git a/test/extra/CMakeLists.txt b/test/extra/CMakeLists.txt index 14fceaa71..9dbbb5b80 100644 --- a/test/extra/CMakeLists.txt +++ b/test/extra/CMakeLists.txt @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2017-2025 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -64,10 +64,8 @@ set(CMAKE_${PROJECT_LANG}_STANDARD_REQUIRED ON) set(CMAKE_${PROJECT_LANG}_EXTENSIONS OFF) set(CMAKE_${PROJECT_LANG}_FLAGS "${CMAKE_${PROJECT_LANG}_FLAGS} -Wall -Wextra -Werror") -if (CMAKE_CXX_STANDARD EQUAL 14) - message(WARNING "C++14 will be deprecated in the next major release") -elseif(NOT CMAKE_CXX_STANDARD EQUAL 17) - message(FATAL_ERROR "Only C++14 and C++17 are supported") +if(NOT CMAKE_CXX_STANDARD EQUAL 17) + message(FATAL_ERROR "Only C++17 is supported") endif() # Enable testing (ctest) diff --git a/test/rocprim/CMakeLists.txt b/test/rocprim/CMakeLists.txt index faebe247e..38bd109d5 100644 --- a/test/rocprim/CMakeLists.txt +++ b/test/rocprim/CMakeLists.txt @@ -243,10 +243,6 @@ function(add_rocprim_cpp_standard_test STANDARD EXTENSIONS TARGET_SUFFIX TEST_NA endif() endfunction() -function(add_rocprim_cpp17_test TEST_NAME TEST_SOURCES) - add_rocprim_cpp_standard_test(17 OFF "" ${TEST_NAME} ${TEST_SOURCES}) -endfunction() - # **************************************************************************** # Tests # **************************************************************************** @@ -264,6 +260,7 @@ if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") add_rocprim_test_parallel("rocprim.block_adjacent_difference" test_block_adjacent_difference.cpp.in) add_rocprim_test_parallel("rocprim.block_discontinuity" test_block_discontinuity.cpp.in) endif() +add_rocprim_test("rocprim.bit_cast" test_bit_cast.cpp) add_rocprim_test("rocprim.block_exchange" test_block_exchange.cpp) add_rocprim_test("rocprim.block_histogram" test_block_histogram.cpp) add_rocprim_test("rocprim.block_load_store" test_block_load_store.cpp) @@ -289,8 +286,8 @@ add_rocprim_test("rocprim.device_histogram" test_device_histogram.cpp) add_rocprim_test("rocprim.device_merge" test_device_merge.cpp) add_rocprim_test("rocprim.device_merge_inplace" test_device_merge_inplace.cpp) add_rocprim_test("rocprim.device_merge_sort" test_device_merge_sort.cpp) -add_rocprim_cpp17_test("rocprim.nth_element" test_device_nth_element.cpp) -add_rocprim_cpp17_test("rocprim.device_partial_sort" test_device_partial_sort.cpp) +add_rocprim_test("rocprim.nth_element" test_device_nth_element.cpp) +add_rocprim_test("rocprim.device_partial_sort" test_device_partial_sort.cpp) add_rocprim_test("rocprim.device_partition" test_device_partition.cpp) add_rocprim_test_parallel("rocprim.device_radix_sort" test_device_radix_sort.cpp.in) add_rocprim_test("rocprim.device_reduce_by_key" test_device_reduce_by_key.cpp) @@ -314,8 +311,6 @@ add_rocprim_test("rocprim.thread" test_thread.cpp) add_rocprim_test("rocprim.thread_algos" test_thread_algos.cpp) add_rocprim_test("rocprim.utils_sort_checker" test_utils_sort_checker.cpp) add_rocprim_test("rocprim.transform_iterator" test_transform_iterator.cpp) -# add_rocprim_cpp_standard_test(14 OFF "_cpp14" "rocprim.type_traits_interface" test_type_traits_interface.cpp) -# add_rocprim_cpp_standard_test(14 ON "_gnupp14" "rocprim.type_traits_interface" test_type_traits_interface.cpp) add_rocprim_cpp_standard_test(17 OFF "_cpp17" "rocprim.type_traits_interface" test_type_traits_interface.cpp) add_rocprim_cpp_standard_test(17 ON "_gnupp17" "rocprim.type_traits_interface" test_type_traits_interface.cpp) add_rocprim_cpp_standard_test(20 OFF "_cpp20" "rocprim.type_traits_interface" test_type_traits_interface.cpp) @@ -331,6 +326,7 @@ add_rocprim_test("rocprim.warp_scan" test_warp_scan.cpp) add_rocprim_test("rocprim.warp_sort" test_warp_sort.cpp) add_rocprim_test("rocprim.warp_store" test_warp_store.cpp) add_rocprim_test("rocprim.zip_iterator" test_zip_iterator.cpp) +add_rocprim_test("rocprim.accumulator_t" test_accumulator_t.cpp) if(NOT WIN32) # Linking tests check if all external rocPRIM symbols are in the inline namespace, kernel are not diff --git a/test/rocprim/test_accumulator_t.cpp b/test/rocprim/test_accumulator_t.cpp new file mode 100644 index 000000000..bf821359f --- /dev/null +++ b/test/rocprim/test_accumulator_t.cpp @@ -0,0 +1,149 @@ +// MIT License +// +// Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include "../common_test_header.hpp" + +#include "../../common/utils_custom_type.hpp" + +#include +#include +#include +#include + +#include +#include + +template +struct const_ref_op +{ + ROCPRIM_INLINE ROCPRIM_HOST_DEVICE + constexpr const T& + operator()([[maybe_unused]] T a, [[maybe_unused]] T b) + { + return value; + } + +private: + static T value; +}; + +// Params for tests +template +struct AccumulatorParams +{ + using input_type = InputType; + using op_type = ScanOp; +}; + +template +class RocprimAccumulatorTests : public ::testing::Test +{ +public: + using input_type = typename Params::input_type; + using op_type = typename Params::op_type; +}; + +using input_types = ::testing::Types, + ::common::custom_type, + ::common::custom_type, + ::common::custom_type, + ::common::custom_type>; + +template +using binary_ops_template = ::testing::Types<::rocprim::less, + ::rocprim::less_equal, + ::rocprim::greater, + ::rocprim::greater_equal, + ::rocprim::equal_to, + ::rocprim::not_equal_to, + ::rocprim::plus, + ::rocprim::minus, + ::rocprim::multiplies, + ::rocprim::maximum, + ::rocprim::minimum, + const_ref_op>; + +template +struct FlattenHelper; + +template +struct FlattenHelper<::testing::Types> +{ + using type = ::testing::Types; +}; + +template +struct FlattenHelper<::testing::Types, ::testing::Types, Rest...> +{ + using type = typename FlattenHelper<::testing::Types, Rest...>::type; +}; + +template +struct GenerateEachInputTypeParams; + +template +struct GenerateEachInputTypeParams> +{ + using type = ::testing::Types...>; +}; + +template +struct GenerateAllParams; + +template +struct GenerateAllParams<::testing::Types> +{ + using type = typename FlattenHelper< + typename GenerateEachInputTypeParams>::type...>::type; +}; + +using RocprimAccumulatorTestsParams = GenerateAllParams::type; + +TYPED_TEST_SUITE(RocprimAccumulatorTests, RocprimAccumulatorTestsParams); + +// Test `accumulator_t` with `const_ref_op` and all binary operators in rocPRIM +// This is tested in compile time, so it can be compiled if it's compatible with all binary operators. +TYPED_TEST(RocprimAccumulatorTests, PointerToAccType) +{ + using T = typename TestFixture::input_type; + using Op = typename TestFixture::op_type; + + using acc_type = ::rocprim::accumulator_t; + [[maybe_unused]] acc_type* unused = nullptr; +} diff --git a/test/rocprim/test_bit_cast.cpp b/test/rocprim/test_bit_cast.cpp new file mode 100644 index 000000000..14537dedb --- /dev/null +++ b/test/rocprim/test_bit_cast.cpp @@ -0,0 +1,213 @@ +// MIT License +// +// Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PUrocprimOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include "../common_test_header.hpp" + +namespace +{ +// Non-trivial types for illegal combination tests +struct NonTrivial32 +{ + int a; + NonTrivial32() : a(0) {} + NonTrivial32(const NonTrivial32& other) : a(other.a) {} +}; + +struct NonTrivial64 +{ + long long a; + NonTrivial64() : a(0) {} + NonTrivial64(const NonTrivial64& other) : a(other.a) {} +}; + +struct NonTrivial128 +{ + rocprim::int128_t a; + NonTrivial128() : a(0) {} + NonTrivial128(const NonTrivial128& other) : a(other.a) {} +}; + +// Self-defined groups for +struct Group32_I +{ + char a; + uint8_t b; + short c; +}; + +struct Group32_II +{ + unsigned short a; + rocprim::half b; +}; + +struct Group64_I +{ + uint8_t a; + char b; + short c; + int d; +}; + +struct Group64_II +{ + long long a; +}; + +static_assert(rocprim::detail::is_valid_bit_cast, + "input types (int, float) must be valid!"); +static_assert(!rocprim::detail::is_valid_bit_cast, + "input types (int, double) must be invalid!"); +static_assert(!rocprim::detail::is_valid_bit_cast, + "input types (int, NonTrivial32) must be invalid!"); +static_assert(rocprim::detail::is_valid_bit_cast, + "input types (int, Group32_I) must be valid!"); +} // namespace + +template +void TestBitCastCombinationImpl(const Source& source) +{ + if constexpr(rocprim::detail::is_valid_bit_cast) + { + Destination dest_memcpy; + std::memcpy(&dest_memcpy, &source, sizeof(Destination)); + Destination dest_bitcast = rocprim::detail::bit_cast(source); + ASSERT_EQ(std::memcmp(&dest_memcpy, &dest_bitcast, sizeof(Destination)), 0); + } +} + +template +void TestBitCastForDestinationsImpl(const Source& source, std::index_sequence) +{ + (void)std::initializer_list{ + (TestBitCastCombinationImpl>(source), 0)...}; +} + +template +void TestBitCastForDestinations(const Source& source) +{ + TestBitCastForDestinationsImpl( + source, + std::make_index_sequence>{}); +} + +// Types in google test arrray +using AllTypes = ::testing::Types< + // 8 bit + char, + unsigned char, + int8_t, + uint8_t, + // 16 bit + short, + unsigned short, + int16_t, + uint16_t, + rocprim::half, + rocprim::bfloat16, + // 32 bit + int, + unsigned int, + int32_t, + uint32_t, + float, + // 64 bit + long long, + unsigned long long, + int64_t, + uint64_t, + double, + // 128 bit + rocprim::int128_t, + rocprim::uint128_t, + // non trivial + NonTrivial32, + NonTrivial64, + NonTrivial128, + // self-defined structs + Group32_I, + Group32_II, + Group64_I, + Group64_II>; + +// Types in tuple for interation +using AllTypesTuple = std::tuple< + // 8 bit + char, + unsigned char, + int8_t, + uint8_t, + // 16 bit + short, + unsigned short, + int16_t, + uint16_t, + rocprim::half, + rocprim::bfloat16, + // 32 bit + int, + unsigned int, + int32_t, + uint32_t, + float, + // 64 bit + long long, + unsigned long long, + int64_t, + uint64_t, + double, + // 128 bit + rocprim::int128_t, + rocprim::uint128_t, + // non trivial + NonTrivial32, + NonTrivial64, + NonTrivial128, + // self-defined structs + Group32_I, + Group32_II, + Group64_I, + Group64_II>; + +template +class BitCastPairTest : public ::testing::Test +{ +public: + using SourceType = Source; +}; + +TYPED_TEST_SUITE(BitCastPairTest, AllTypes); + +TYPED_TEST(BitCastPairTest, BitCastPairTest) +{ + using Source = typename TestFixture::SourceType; + + unsigned char buffer[sizeof(Source)]; + for(size_t i = 0; i < sizeof(Source); ++i) + { + buffer[i] = static_cast(rand() & 0xFF); + } + Source source; + std::memcpy(reinterpret_cast(&source), buffer, sizeof(Source)); + + TestBitCastForDestinations(source); +} diff --git a/test/rocprim/test_block_adjacent_difference.cpp.in b/test/rocprim/test_block_adjacent_difference.cpp.in index f0d1eded7..438cc23b6 100644 --- a/test/rocprim/test_block_adjacent_difference.cpp.in +++ b/test/rocprim/test_block_adjacent_difference.cpp.in @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -37,29 +37,34 @@ // Start stamping out tests struct RocprimBlockAdjacentDifference; -#cmakedefine ROCPRIM_TEST_SLICE @ROCPRIM_TEST_SLICE@ - -#if ROCPRIM_TEST_SLICE == 0 - -struct Integral; -#define suite_name RocprimBlockAdjacentDifference -#define warp_params BlockDiscParamsIntegral -#define name_suffix Integral - -#elif ROCPRIM_TEST_SLICE == 1 - -struct Floating; -#define suite_name RocprimBlockAdjacentDifference -#define warp_params BlockDiscParamsFloating -#define name_suffix Floating +#if !_CLANGD + #cmakedefine ROCPRIM_TEST_SLICE @ROCPRIM_TEST_SLICE@ +#endif -#elif ROCPRIM_TEST_SLICE == 2 +#if ROCPRIM_TEST_SLICE == 0 || _CLANGD + struct Integral; + #define suite_name RocprimBlockAdjacentDifference + #define warp_params BlockDiscParamsIntegral + #define name_suffix Integral +#endif -struct FloatingHalf; -#define suite_name RocprimBlockAdjacentDifference -#define warp_params BlockDiscParamsFloatingHalf -#define name_suffix FloatingHalf +#if ROCPRIM_TEST_SLICE == 1 || _CLANGD + struct Floating; + #define suite_name RocprimBlockAdjacentDifference + #define warp_params BlockDiscParamsFloating + #define name_suffix Floating +#endif +#if ROCPRIM_TEST_SLICE == 2 || _CLANGD + struct FloatingHalf; + #define suite_name RocprimBlockAdjacentDifference + #define warp_params BlockDiscParamsFloatingHalf + #define name_suffix FloatingHalf #endif -#include "test_block_adjacent_difference.hpp" +#if !_CLANGD + // When using clangd, the '.cpp.in' file is already included + // in the header. To prevent recursive includes, the header is + // not supposed to be added. + #include "test_block_adjacent_difference.hpp" +#endif diff --git a/test/rocprim/test_block_adjacent_difference.hpp b/test/rocprim/test_block_adjacent_difference.hpp index cb4047839..1cf983c40 100644 --- a/test/rocprim/test_block_adjacent_difference.hpp +++ b/test/rocprim/test_block_adjacent_difference.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,57 +20,17 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. +#ifdef _CLANGD + // When using clangd, to allow the language server to function properly, + // some context of the template source is required to allow the language + // server to function properly. + #include "test_block_adjacent_difference.cpp.in" +#endif + test_suite_type_def(suite_name, name_suffix) typed_test_suite_def(RocprimBlockAdjacentDifference, name_suffix, warp_params); -typed_test_def(RocprimBlockAdjacentDifference, name_suffix, FlagHeads) -{ - using type = typename TestFixture::params::input_type; - using flag_type = typename TestFixture::params::output_type; - using flag_op_type_1 = rocprim::less; - using flag_op_type_2 = rocprim::equal_to; - using flag_op_type_3 = rocprim::greater; - using flag_op_type_4 = rocprim::not_equal_to; - constexpr size_t block_size = TestFixture::params::block_size; - - static_for<0, 2, type, flag_type, flag_op_type_1, 0, block_size>::run(); - static_for<2, 4, type, flag_type, flag_op_type_2, 0, block_size>::run(); - static_for<4, 6, type, flag_type, flag_op_type_3, 0, block_size>::run(); - static_for<6, n_items, type, flag_type, flag_op_type_4, 0, block_size>::run(); -} - -typed_test_def(RocprimBlockAdjacentDifference, name_suffix, FlagTails) -{ - using type = typename TestFixture::params::input_type; - using flag_type = typename TestFixture::params::output_type; - using flag_op_type_1 = rocprim::less; - using flag_op_type_2 = rocprim::equal_to; - using flag_op_type_3 = rocprim::greater; - using flag_op_type_4 = rocprim::not_equal_to; - constexpr size_t block_size = TestFixture::params::block_size; - static_for<0, 2, type, flag_type, flag_op_type_1, 1, block_size>::run(); - static_for<2, 4, type, flag_type, flag_op_type_2, 1, block_size>::run(); - static_for<4, 6, type, flag_type, flag_op_type_3, 1, block_size>::run(); - static_for<6, n_items, type, flag_type, flag_op_type_4, 1, block_size>::run(); -} - -typed_test_def(RocprimBlockAdjacentDifference, name_suffix, FlagHeadsAndTails) -{ - using type = typename TestFixture::params::input_type; - using flag_type = typename TestFixture::params::output_type; - using flag_op_type_1 = rocprim::less; - using flag_op_type_2 = rocprim::equal_to; - using flag_op_type_3 = rocprim::greater; - using flag_op_type_4 = rocprim::not_equal_to; - constexpr size_t block_size = TestFixture::params::block_size; - - static_for<0, 2, type, flag_type, flag_op_type_1, 2, block_size>::run(); - static_for<2, 4, type, flag_type, flag_op_type_2, 2, block_size>::run(); - static_for<4, 6, type, flag_type, flag_op_type_3, 2, block_size>::run(); - static_for<6, n_items, type, flag_type, flag_op_type_4, 2, block_size>::run(); -} - typed_test_def(RocprimBlockAdjacentDifference, name_suffix, SubtractLeft) { using T = typename TestFixture::params::input_type; @@ -82,9 +42,9 @@ typed_test_def(RocprimBlockAdjacentDifference, name_suffix, SubtractLeft) constexpr size_t block_size = TestFixture::params::block_size; // clang-format off - static_for<0, 2, T, T, op_type_1, 3, block_size>::run(); - static_for<2, 4, T, T, op_type_2, 3, block_size>::run(); - static_for<4, n_items, T, T, op_type_3, 3, block_size>::run(); + static_for<0, 2, T, T, op_type_1, 0, block_size>::run(); + static_for<2, 4, T, T, op_type_2, 0, block_size>::run(); + static_for<4, n_items, T, T, op_type_3, 0, block_size>::run(); // clang-format on } @@ -99,9 +59,9 @@ typed_test_def(RocprimBlockAdjacentDifference, name_suffix, SubtractRight) constexpr size_t block_size = TestFixture::params::block_size; // clang-format off - static_for<0, 2, T, T, op_type_1, 4, block_size>::run(); - static_for<2, 4, T, T, op_type_2, 4, block_size>::run(); - static_for<4, n_items, T, T, op_type_3, 4, block_size>::run(); + static_for<0, 2, T, T, op_type_1, 1, block_size>::run(); + static_for<2, 4, T, T, op_type_2, 1, block_size>::run(); + static_for<4, n_items, T, T, op_type_3, 1, block_size>::run(); // clang-format on } @@ -116,9 +76,9 @@ typed_test_def(RocprimBlockAdjacentDifference, name_suffix, SubtractLeftPartial) constexpr size_t block_size = TestFixture::params::block_size; // clang-format off - static_for<0, 2, T, T, op_type_1, 5, block_size>::run(); - static_for<2, 4, T, T, op_type_2, 5, block_size>::run(); - static_for<4, n_items, T, T, op_type_3, 5, block_size>::run(); + static_for<0, 2, T, T, op_type_1, 2, block_size>::run(); + static_for<2, 4, T, T, op_type_2, 2, block_size>::run(); + static_for<4, n_items, T, T, op_type_3, 2, block_size>::run(); // clang-format on } @@ -133,8 +93,8 @@ typed_test_def(RocprimBlockAdjacentDifference, name_suffix, SubtractRightPartial constexpr size_t block_size = TestFixture::params::block_size; // clang-format off - static_for<0, 2, T, T, op_type_1, 6, block_size>::run(); - static_for<2, 4, T, T, op_type_2, 6, block_size>::run(); - static_for<4, n_items, T, T, op_type_3, 6, block_size>::run(); + static_for<0, 2, T, T, op_type_1, 3, block_size>::run(); + static_for<2, 4, T, T, op_type_2, 3, block_size>::run(); + static_for<4, n_items, T, T, op_type_3, 3, block_size>::run(); // clang-format on } diff --git a/test/rocprim/test_block_adjacent_difference.kernels.hpp b/test/rocprim/test_block_adjacent_difference.kernels.hpp index 6b5cd6f3b..05f17f779 100644 --- a/test/rocprim/test_block_adjacent_difference.kernels.hpp +++ b/test/rocprim/test_block_adjacent_difference.kernels.hpp @@ -23,10 +23,16 @@ #ifndef TEST_BLOCK_ADJACENT_DIFFERENCE_KERNELS_HPP_ #define TEST_BLOCK_ADJACENT_DIFFERENCE_KERNELS_HPP_ +#include "test_utils.hpp" +#include "test_utils_types.hpp" + #include "../common_test_header.hpp" #include "../../common/utils_device_ptr.hpp" -#include "test_utils.hpp" + +#include +#include +#include // Host (CPU) implementaions of the wrapping function that allows to pass 3 args template @@ -52,46 +58,6 @@ struct test_op } }; -template -__global__ __launch_bounds__(BlockSize, ROCPRIM_DEFAULT_MIN_WARPS_PER_EU) void flag_heads_kernel( - Type* device_input, long long* device_heads) -{ - const unsigned int lid = threadIdx.x; - const unsigned int items_per_block = BlockSize * ItemsPerThread; - const unsigned int block_offset = blockIdx.x * items_per_block; - - Type input[ItemsPerThread]; - rocprim::block_load_direct_blocked(lid, device_input + block_offset, input); - - rocprim::block_adjacent_difference adjacent_difference; - __shared__ typename decltype(adjacent_difference)::storage_type storage; - - FlagType head_flags[ItemsPerThread]; - - // Still need to test it even tough its deprecated - ROCPRIM_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated") - if(blockIdx.x % 2 == 1) - { - const Type tile_predecessor_item = device_input[block_offset - 1]; - adjacent_difference.flag_heads(head_flags, - tile_predecessor_item, - input, - FlagOpType(), - storage); - } - else - { - adjacent_difference.flag_heads(head_flags, input, FlagOpType(), storage); - } - ROCPRIM_CLANG_SUPPRESS_WARNING_POP - - rocprim::block_store_direct_blocked(lid, device_heads + block_offset, head_flags); -} - template -__global__ __launch_bounds__(BlockSize, ROCPRIM_DEFAULT_MIN_WARPS_PER_EU) void flag_tails_kernel( - Type* device_input, long long* device_tails) -{ - const unsigned int lid = threadIdx.x; - const unsigned int items_per_block = BlockSize * ItemsPerThread; - const unsigned int block_offset = blockIdx.x * items_per_block; - - Type input[ItemsPerThread]; - rocprim::block_load_direct_blocked(lid, device_input + block_offset, input); - - rocprim::block_adjacent_difference adjacent_difference; - __shared__ typename decltype(adjacent_difference)::storage_type storage; - - FlagType tail_flags[ItemsPerThread]; - - // Still need to test it even tough its deprecated - ROCPRIM_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated") - if(blockIdx.x % 2 == 0) - { - const Type tile_successor_item = device_input[block_offset + items_per_block]; - adjacent_difference.flag_tails(tail_flags, - tile_successor_item, - input, - FlagOpType(), - storage); - } - else - { - adjacent_difference.flag_tails(tail_flags, input, FlagOpType(), storage); - } - ROCPRIM_CLANG_SUPPRESS_WARNING_POP - - rocprim::block_store_direct_blocked(lid, device_tails + block_offset, tail_flags); -} - -template -__global__ __launch_bounds__( - BlockSize, - ROCPRIM_DEFAULT_MIN_WARPS_PER_EU) void flag_heads_and_tails_kernel(Type* device_input, - long long* device_heads, - long long* device_tails) -{ - const unsigned int lid = threadIdx.x; - const unsigned int items_per_block = BlockSize * ItemsPerThread; - const unsigned int block_offset = blockIdx.x * items_per_block; - - Type input[ItemsPerThread]; - rocprim::block_load_direct_blocked(lid, device_input + block_offset, input); - - rocprim::block_adjacent_difference adjacent_difference; - __shared__ typename decltype(adjacent_difference)::storage_type storage; - - FlagType head_flags[ItemsPerThread]; - FlagType tail_flags[ItemsPerThread]; - - // Still need to test it even tough its deprecated - ROCPRIM_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated") - if(blockIdx.x % 4 == 0) - { - const Type tile_successor_item = device_input[block_offset + items_per_block]; - adjacent_difference.flag_heads_and_tails(head_flags, - tail_flags, - tile_successor_item, - input, - FlagOpType(), - storage); - } - else if(blockIdx.x % 4 == 1) - { - const Type tile_predecessor_item = device_input[block_offset - 1]; - const Type tile_successor_item = device_input[block_offset + items_per_block]; - adjacent_difference.flag_heads_and_tails(head_flags, - tile_predecessor_item, - tail_flags, - tile_successor_item, - input, - FlagOpType(), - storage); - } - else if(blockIdx.x % 4 == 2) - { - const Type tile_predecessor_item = device_input[block_offset - 1]; - adjacent_difference.flag_heads_and_tails(head_flags, - tile_predecessor_item, - tail_flags, - input, - FlagOpType(), - storage); - } - else if(blockIdx.x % 4 == 3) - { - adjacent_difference.flag_heads_and_tails(head_flags, - tail_flags, - input, - FlagOpType(), - storage); - } - ROCPRIM_CLANG_SUPPRESS_WARNING_POP - - rocprim::block_store_direct_blocked(lid, device_heads + block_offset, head_flags); - rocprim::block_store_direct_blocked(lid, device_tails + block_offset, tail_flags); -} - -template -auto test_block_adjacent_difference() -> typename std::enable_if::type -{ - using type = Type; - // std::vector is a special case that will cause an error in hipMemcpy - // rocprim::half/rocprim::bfloat16 are special cases that cannot be compared '==' - // in ASSERT_EQ - using stored_flag_type = typename std::conditional< - std::is_same::value, - int, - typename std::conditional::value - || std::is_same::value, - float, - FlagType>::type>::type; - using flag_type = FlagType; - using flag_op_type = FlagOpType; - static constexpr size_t block_size = BlockSize; - static constexpr size_t items_per_thread = ItemsPerThread; - static constexpr size_t items_per_block = block_size * items_per_thread; - static constexpr size_t size = items_per_block * 20; - static constexpr size_t grid_size = size / items_per_block; - - SCOPED_TRACE(testing::Message() << "items_per_block = " << items_per_block); - SCOPED_TRACE(testing::Message() << "size = " << size); - SCOPED_TRACE(testing::Message() << "grid_size = " << grid_size); - - // Given block size not supported - if(block_size > test_utils::get_max_block_size()) - { - return; - } - - for(size_t seed_index = 0; seed_index < number_of_runs; seed_index++) - { - unsigned int seed_value - = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; - SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); - - // Generate data - std::vector input - = test_utils::get_random_data_wrapped(size, 0, 10, seed_value); - - // Calculate expected results on host - std::vector expected_heads(size); - flag_op_type flag_op; - for(size_t bi = 0; bi < size / items_per_block; bi++) - { - for(size_t ii = 0; ii < items_per_block; ii++) - { - const size_t i = bi * items_per_block + ii; - if(ii == 0) - { - expected_heads[i] = bi % 2 == 1 - ? apply(flag_op, - input[i - 1], - input[i], - ii) - : stored_flag_type(true); - } - else - { - expected_heads[i] - = apply(flag_op, input[i - 1], input[i], ii); - } - } - } - - // Preparing Device - common::device_ptr device_input(input); - common::device_ptr device_heads(size); - - // Running kernel - hipLaunchKernelGGL( - HIP_KERNEL_NAME( - flag_heads_kernel), - dim3(grid_size), - dim3(block_size), - 0, - 0, - device_input.get(), - device_heads.get()); - HIP_CHECK(hipGetLastError()); - HIP_CHECK(hipDeviceSynchronize()); - - // Reading results - const auto heads = device_heads.load(); - - // Validating results - for(size_t i = 0; i < size; i++) - { - ASSERT_EQ(heads[i], expected_heads[i]); - } - } -} - -template -auto test_block_adjacent_difference() -> typename std::enable_if::type -{ - using type = Type; - // std::vector is a special case that will cause an error in hipMemcpy - // rocprim::half/rocprim::bfloat16 are special cases that cannot be compared '==' - // in ASSERT_EQ - using stored_flag_type = typename std::conditional< - std::is_same::value, - int, - typename std::conditional::value - || std::is_same::value, - float, - FlagType>::type>::type; - using flag_type = FlagType; - using flag_op_type = FlagOpType; - static constexpr size_t block_size = BlockSize; - static constexpr size_t items_per_thread = ItemsPerThread; - static constexpr size_t items_per_block = block_size * items_per_thread; - static constexpr size_t size = items_per_block * 20; - static constexpr size_t grid_size = size / items_per_block; - - SCOPED_TRACE(testing::Message() << "items_per_block = " << items_per_block); - SCOPED_TRACE(testing::Message() << "size = " << size); - SCOPED_TRACE(testing::Message() << "grid_size = " << grid_size); - - // Given block size not supported - if(block_size > test_utils::get_max_block_size()) - { - return; - } - - for(size_t seed_index = 0; seed_index < number_of_runs; seed_index++) - { - unsigned int seed_value - = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; - SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); - - // Generate data - std::vector input - = test_utils::get_random_data_wrapped(size, 0, 10, seed_value); - - // Calculate expected results on host - std::vector expected_tails(size); - flag_op_type flag_op; - for(size_t bi = 0; bi < size / items_per_block; bi++) - { - for(size_t ii = 0; ii < items_per_block; ii++) - { - const size_t i = bi * items_per_block + ii; - if(ii == items_per_block - 1) - { - expected_tails[i] = bi % 2 == 0 - ? apply(flag_op, - input[i], - input[i + 1], - ii + 1) - : stored_flag_type(true); - } - else - { - expected_tails[i] = apply(flag_op, - input[i], - input[i + 1], - ii + 1); - } - } - } - - // Preparing Device - common::device_ptr device_input(input); - common::device_ptr device_tails(size); - - // Running kernel - hipLaunchKernelGGL( - HIP_KERNEL_NAME( - flag_tails_kernel), - dim3(grid_size), - dim3(block_size), - 0, - 0, - device_input.get(), - device_tails.get()); - HIP_CHECK(hipGetLastError()); - HIP_CHECK(hipDeviceSynchronize()); - - // Reading results - const auto tails = device_tails.load(); - - // Validating results - for(size_t i = 0; i < size; i++) - { - ASSERT_EQ(tails[i], expected_tails[i]); - } - } -} - -template -auto test_block_adjacent_difference() -> typename std::enable_if::type -{ - using type = Type; - // std::vector is a special case that will cause an error in hipMemcpy - // rocprim::half/rocprim::bfloat16 are special cases that cannot be compared '==' - // in ASSERT_EQ - using stored_flag_type = typename std::conditional< - std::is_same::value, - int, - typename std::conditional::value - || std::is_same::value, - float, - FlagType>::type>::type; - using flag_type = FlagType; - using flag_op_type = FlagOpType; - static constexpr size_t block_size = BlockSize; - static constexpr size_t items_per_thread = ItemsPerThread; - static constexpr size_t items_per_block = block_size * items_per_thread; - static constexpr size_t size = items_per_block * 20; - static constexpr size_t grid_size = size / items_per_block; - - SCOPED_TRACE(testing::Message() << "items_per_block = " << items_per_block); - SCOPED_TRACE(testing::Message() << "size = " << size); - SCOPED_TRACE(testing::Message() << "grid_size = " << grid_size); - - // Given block size not supported - if(block_size > test_utils::get_max_block_size()) - { - return; - } - - for(size_t seed_index = 0; seed_index < number_of_runs; seed_index++) - { - unsigned int seed_value - = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; - SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); - - // Generate data - std::vector input - = test_utils::get_random_data_wrapped(size, 0, 10, seed_value); - - // Calculate expected results on host - std::vector expected_heads(size); - std::vector expected_tails(size); - flag_op_type flag_op; - for(size_t bi = 0; bi < size / items_per_block; bi++) - { - for(size_t ii = 0; ii < items_per_block; ii++) - { - const size_t i = bi * items_per_block + ii; - if(ii == 0) - { - expected_heads[i] = (bi % 4 == 1 || bi % 4 == 2) - ? apply(flag_op, - input[i - 1], - input[i], - ii) - : stored_flag_type(true); - } - else - { - expected_heads[i] - = apply(flag_op, input[i - 1], input[i], ii); - } - if(ii == items_per_block - 1) - { - expected_tails[i] = (bi % 4 == 0 || bi % 4 == 1) - ? apply(flag_op, - input[i], - input[i + 1], - ii + 1) - : stored_flag_type(true); - } - else - { - expected_tails[i] = apply(flag_op, - input[i], - input[i + 1], - ii + 1); - } - } - } - - // Preparing Device - common::device_ptr device_input(input); - common::device_ptr device_heads(size); - common::device_ptr device_tails(size); - - // Running kernel - hipLaunchKernelGGL(HIP_KERNEL_NAME(flag_heads_and_tails_kernel), - dim3(grid_size), - dim3(block_size), - 0, - 0, - device_input.get(), - device_heads.get(), - device_tails.get()); - HIP_CHECK(hipGetLastError()); - HIP_CHECK(hipDeviceSynchronize()); - - // Reading results - const auto heads = device_heads.load(); - const auto tails = device_tails.load(); - // Validating results - for(size_t i = 0; i < size; i++) - { - ASSERT_EQ(heads[i], expected_heads[i]); - ASSERT_EQ(tails[i], expected_tails[i]); - } - } -} - template -auto test_block_adjacent_difference() -> typename std::enable_if::type +auto test_block_adjacent_difference() -> typename std::enable_if::type { using stored_type = std::conditional_t::value, int, Output>; @@ -779,7 +309,7 @@ template -auto test_block_adjacent_difference() -> typename std::enable_if::type +auto test_block_adjacent_difference() -> typename std::enable_if::type { using stored_type = std::conditional_t::value, int, Output>; @@ -866,7 +396,7 @@ template -auto test_block_adjacent_difference() -> typename std::enable_if::type +auto test_block_adjacent_difference() -> typename std::enable_if::type { using stored_type = std::conditional_t::value, int, Output>; @@ -968,7 +498,7 @@ template -auto test_block_adjacent_difference() -> typename std::enable_if::type +auto test_block_adjacent_difference() -> typename std::enable_if::type { using stored_type = std::conditional_t::value, int, Output>; diff --git a/test/rocprim/test_block_discontinuity.cpp.in b/test/rocprim/test_block_discontinuity.cpp.in index 3fd743f5d..6bff2f678 100644 --- a/test/rocprim/test_block_discontinuity.cpp.in +++ b/test/rocprim/test_block_discontinuity.cpp.in @@ -40,29 +40,34 @@ // Start stamping out tests struct RocprimBlockDiscontinuity; -#cmakedefine ROCPRIM_TEST_SLICE @ROCPRIM_TEST_SLICE@ - -#if ROCPRIM_TEST_SLICE == 0 - -struct Integral; -#define suite_name RocprimBlockDiscontinuity -#define warp_params BlockDiscParamsIntegral -#define name_suffix Integral - -#elif ROCPRIM_TEST_SLICE == 1 - -struct Floating; -#define suite_name RocprimBlockDiscontinuity -#define warp_params BlockDiscParamsFloating -#define name_suffix Floating +#if !_CLANGD + #cmakedefine ROCPRIM_TEST_SLICE @ROCPRIM_TEST_SLICE@ +#endif -#elif ROCPRIM_TEST_SLICE == 2 +#if ROCPRIM_TEST_SLICE == 0 || _CLANGD + struct Integral; + #define suite_name RocprimBlockDiscontinuity + #define warp_params BlockDiscParamsIntegral + #define name_suffix Integral +#endif -struct FloatingHalf; -#define suite_name RocprimBlockDiscontinuity -#define warp_params BlockDiscParamsFloatingHalf -#define name_suffix FloatingHalf +#if ROCPRIM_TEST_SLICE == 1 || _CLANGD + struct Floating; + #define suite_name RocprimBlockDiscontinuity + #define warp_params BlockDiscParamsFloating + #define name_suffix Floating +#endif +#if ROCPRIM_TEST_SLICE == 2 || _CLANGD + struct FloatingHalf; + #define suite_name RocprimBlockDiscontinuity + #define warp_params BlockDiscParamsFloatingHalf + #define name_suffix FloatingHalf #endif -#include "test_block_discontinuity.hpp" +#if !_CLANGD + // When using clangd, the '.cpp.in' file is already included + // in the header. To prevent recursive includes, the header is + // not supposed to be added. + #include "test_block_discontinuity.hpp" +#endif diff --git a/test/rocprim/test_block_discontinuity.hpp b/test/rocprim/test_block_discontinuity.hpp index a417d163c..84faa6d81 100644 --- a/test/rocprim/test_block_discontinuity.hpp +++ b/test/rocprim/test_block_discontinuity.hpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,6 +20,13 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. +#ifdef _CLANGD + // When using clangd, to allow the language server to function properly, + // some context of the template source is required to allow the language + // server to function properly. + #include "test_block_discontinuity.cpp.in" +#endif + test_suite_type_def(suite_name, name_suffix) typed_test_suite_def(RocprimBlockDiscontinuity, name_suffix, warp_params); diff --git a/test/rocprim/test_block_histogram.kernels.hpp b/test/rocprim/test_block_histogram.kernels.hpp index 8351ec426..bde60144d 100644 --- a/test/rocprim/test_block_histogram.kernels.hpp +++ b/test/rocprim/test_block_histogram.kernels.hpp @@ -38,7 +38,6 @@ #include #include #include -#include #include #include diff --git a/test/rocprim/test_block_radix_rank.hpp b/test/rocprim/test_block_radix_rank.hpp index e078fdcfa..ffd4a1930 100644 --- a/test/rocprim/test_block_radix_rank.hpp +++ b/test/rocprim/test_block_radix_rank.hpp @@ -68,11 +68,12 @@ template -__global__ __launch_bounds__(BlockSize) void rank_kernel(const T* const items_input, - unsigned int* const ranks_output, - const bool descending, - const unsigned int start_bit, - const unsigned int radix_bits) +__global__ __launch_bounds__(BlockSize) +void rank_kernel(const T* const items_input, + unsigned int* const ranks_output, + const bool descending, + const unsigned int start_bit, + const unsigned int radix_bits) { using block_rank_type = rocprim::block_radix_rank; using keys_exchange_type = rocprim::block_exchange; @@ -95,7 +96,7 @@ __global__ __launch_bounds__(BlockSize) void rank_kernel(const T* const ite unsigned int ranks[ItemsPerThread]; rocprim::block_load_direct_blocked(lid, items_input + block_offset, keys); - if ROCPRIM_IF_CONSTEXPR(warp_striped) + if constexpr(warp_striped) { // block_radix_rank_match requires warp striped input and output. Instead of using // rocprim::block_load_direct_warp_striped though, we load directly and exchange the @@ -114,7 +115,7 @@ __global__ __launch_bounds__(BlockSize) void rank_kernel(const T* const ite block_rank_type().rank_keys(keys, ranks, storage.rank, start_bit, radix_bits); } - if ROCPRIM_IF_CONSTEXPR(warp_striped) + if constexpr(warp_striped) { // See the comment above. rocprim::syncthreads(); diff --git a/test/rocprim/test_block_radix_sort.kernels.hpp b/test/rocprim/test_block_radix_sort.kernels.hpp index aa18c690e..f0827ebab 100644 --- a/test/rocprim/test_block_radix_sort.kernels.hpp +++ b/test/rocprim/test_block_radix_sort.kernels.hpp @@ -37,7 +37,7 @@ #include #include #include -#include +#include #include #include @@ -84,26 +84,30 @@ __global__ __launch_bounds__(BlockSize) void sort_key_kernel(key_type* device key_type keys[ItemsPerThread]; rocprim::block_load_direct_blocked(lid, device_keys_output + block_offset, keys); - rocprim::block_radix_sort - bsort; - + using block_radix_sort = rocprim::block_radix_sort; + block_radix_sort bsort; test_utils::select_decomposer_t decomposer{}; + // Share LDS storage between all different sort invocations explicitely. + // This resolves local memory allocation exceeding limits when targeting SPIR-V. + __shared__ typename block_radix_sort::storage_type storage; + + // Sort differently depending on passed flags. if(to_striped) { if(descending) { - bsort.sort_desc_to_striped(keys, start_bit, end_bit, decomposer); + bsort.sort_desc_to_striped(keys, storage, start_bit, end_bit, decomposer); } else { - bsort.sort_to_striped(keys, start_bit, end_bit, decomposer); + bsort.sort_to_striped(keys, storage, start_bit, end_bit, decomposer); } rocprim::block_store_direct_striped(lid, device_keys_output + block_offset, keys); @@ -112,11 +116,11 @@ __global__ __launch_bounds__(BlockSize) void sort_key_kernel(key_type* device { if(descending) { - bsort.sort_desc(keys, start_bit, end_bit, decomposer); + bsort.sort_desc(keys, storage, start_bit, end_bit, decomposer); } else { - bsort.sort(keys, start_bit, end_bit, decomposer); + bsort.sort(keys, storage, start_bit, end_bit, decomposer); } rocprim::block_store_direct_blocked(lid, device_keys_output + block_offset, keys); @@ -144,19 +148,25 @@ __global__ __launch_bounds__(BlockSize) void sort_key_value_kernel(key_type* d rocprim::block_load_direct_blocked(lid, device_keys_output + block_offset, keys); rocprim::block_load_direct_blocked(lid, device_values_output + block_offset, values); - rocprim:: - block_radix_sort - bsort; + using block_radix_sort = rocprim:: + block_radix_sort; + block_radix_sort bsort; test_utils::select_decomposer_t decomposer{}; + + // Share LDS storage between all different sort invocations explicitely. + // This resolved local memory allocation exceeding limits when targeting SPIR-V. + __shared__ typename block_radix_sort::storage_type storage; + + // Sort differently depending on passed flags. if(to_striped) { if(descending) { - bsort.sort_desc_to_striped(keys, values, start_bit, end_bit, decomposer); + bsort.sort_desc_to_striped(keys, values, storage, start_bit, end_bit, decomposer); } else { - bsort.sort_to_striped(keys, values, start_bit, end_bit, decomposer); + bsort.sort_to_striped(keys, values, storage, start_bit, end_bit, decomposer); } rocprim::block_store_direct_striped(lid, device_keys_output + block_offset, keys); @@ -166,11 +176,11 @@ __global__ __launch_bounds__(BlockSize) void sort_key_value_kernel(key_type* d { if(descending) { - bsort.sort_desc(keys, values, start_bit, end_bit, decomposer); + bsort.sort_desc(keys, values, storage, start_bit, end_bit, decomposer); } else { - bsort.sort(keys, values, start_bit, end_bit, decomposer); + bsort.sort(keys, values, storage, start_bit, end_bit, decomposer); } rocprim::block_store_direct_blocked(lid, device_keys_output + block_offset, keys); diff --git a/test/rocprim/test_block_scan.cpp.in b/test/rocprim/test_block_scan.cpp.in index 981047ccb..19d1b0c60 100644 --- a/test/rocprim/test_block_scan.cpp.in +++ b/test/rocprim/test_block_scan.cpp.in @@ -40,24 +40,29 @@ struct RocprimBlockScanSingleValueTests; struct RocprimBlockScanInputArrayTests; -#cmakedefine ROCPRIM_TEST_SLICE @ROCPRIM_TEST_SLICE@ - -#if ROCPRIM_TEST_SLICE == 0 - -struct Integral; -#define suite_name_single RocprimBlockScanSingleValueTests -#define suite_name_array RocprimBlockScanInputArrayTests -#define block_params BlockParamsIntegral -#define name_suffix Integral - -#elif ROCPRIM_TEST_SLICE == 1 +#if !_CLANGD + #cmakedefine ROCPRIM_TEST_SLICE @ROCPRIM_TEST_SLICE@ +#endif -struct Floating; -#define suite_name_single RocprimBlockScanSingleValueTests -#define suite_name_array RocprimBlockScanInputArrayTests -#define block_params BlockExchParamsFloating -#define name_suffix Floating +#if ROCPRIM_TEST_SLICE == 0 || _CLANGD + struct Integral; + #define suite_name_single RocprimBlockScanSingleValueTests + #define suite_name_array RocprimBlockScanInputArrayTests + #define block_params BlockParamsIntegral + #define name_suffix Integral +#endif +#if ROCPRIM_TEST_SLICE == 1 || _CLANGD + struct Floating; + #define suite_name_single RocprimBlockScanSingleValueTests + #define suite_name_array RocprimBlockScanInputArrayTests + #define block_params BlockExchParamsFloating + #define name_suffix Floating #endif -#include "test_block_scan.hpp" +#if !_CLANGD + // When using clangd, the '.cpp.in' file is already included + // in the header. To prevent recursive includes, the header is + // not supposed to be added. + #include "test_block_scan.hpp" +#endif diff --git a/test/rocprim/test_block_scan.hpp b/test/rocprim/test_block_scan.hpp index 32ac8c51f..900455f77 100644 --- a/test/rocprim/test_block_scan.hpp +++ b/test/rocprim/test_block_scan.hpp @@ -20,6 +20,13 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. +#ifdef _CLANGD + // When using clangd, to allow the language server to function properly, + // some context of the template source is required to allow the language + // server to function properly. + #include "test_block_scan.cpp.in" +#endif + block_reduce_test_suite_type_def(suite_name_single, name_suffix) block_reduce_test_suite_type_def(suite_name_array, name_suffix) @@ -286,13 +293,18 @@ typed_test_def(suite_name_single, name_suffix, InclusiveScanReduceInitialValue) for(size_t i = 0; i < output.size() / block_size; i++) { acc_type accumulator(initial_value); + acc_type reduction = output[i * block_size]; for(size_t j = 0; j < block_size; j++) { - auto idx = i * block_size + j; + size_t idx = i * block_size + j; accumulator = binary_op_host(output[idx], accumulator); expected[idx] = static_cast(accumulator); + if(j > 0) + { + reduction = binary_op_host(output[idx], reduction); + } } - expected_reductions[i] = expected[(i + 1) * block_size - 1]; + expected_reductions[i] = reduction; } // Writing to device memory diff --git a/test/rocprim/test_block_scan.kernels.hpp b/test/rocprim/test_block_scan.kernels.hpp index 4eca2c418..886993479 100644 --- a/test/rocprim/test_block_scan.kernels.hpp +++ b/test/rocprim/test_block_scan.kernels.hpp @@ -29,6 +29,7 @@ #include "test_utils.hpp" #include "test_utils_assertions.hpp" #include "test_utils_data_generation.hpp" +#include "test_utils_types.hpp" #include #include diff --git a/test/rocprim/test_config_dispatch.cpp b/test/rocprim/test_config_dispatch.cpp index 6697bf10f..bb5e08063 100644 --- a/test/rocprim/test_config_dispatch.cpp +++ b/test/rocprim/test_config_dispatch.cpp @@ -62,6 +62,8 @@ TEST(RocprimConfigDispatchTests, StrEqualN) ASSERT_FALSE(prefix_equals("hasprefix", "hasp", 4)); } +#if !defined(ROCPRIM_EXPERIMENTAL_SPIRV) // This macro disables the config_dispatching + TEST(RocprimConfigDispatchTests, HostMatchesDevice) { const int device_id = test_common_utils::obtain_device_from_ctest(); @@ -101,6 +103,8 @@ TEST(RocprimConfigDispatchTests, ParseCommonArches) ASSERT_EQ(parse_gcn_arch("gfx90a:sramecc+:xnack-"), target_arch::gfx90a); } +#endif // ROCPRIM_EXPERIMENTAL_SPIRV + #ifndef _WIN32 TEST(RocprimConfigDispatchTests, DeviceIdFromStream) { diff --git a/test/rocprim/test_device_find_end.cpp b/test/rocprim/test_device_find_end.cpp index 268cbc8bc..328bb9246 100644 --- a/test/rocprim/test_device_find_end.cpp +++ b/test/rocprim/test_device_find_end.cpp @@ -41,7 +41,6 @@ #include #include #include -#include #include #include diff --git a/test/rocprim/test_device_histogram.cpp b/test/rocprim/test_device_histogram.cpp index b8ba77ce4..ee22d8986 100644 --- a/test/rocprim/test_device_histogram.cpp +++ b/test/rocprim/test_device_histogram.cpp @@ -38,7 +38,6 @@ #include #include #include -#include #include #include diff --git a/test/rocprim/test_device_merge.cpp b/test/rocprim/test_device_merge.cpp index d5a6ca088..a51560bd8 100644 --- a/test/rocprim/test_device_merge.cpp +++ b/test/rocprim/test_device_merge.cpp @@ -80,9 +80,12 @@ class RocprimDeviceMergeTests : public ::testing::Test using custom_int2 = common::custom_type; using custom_double2 = common::custom_type; +using custom_large = common::custom_huge_type<1024, long long>; using RocprimDeviceMergeTestsParams = ::testing::Types< DeviceMergeParams, + DeviceMergeParams, + DeviceMergeParams, DeviceMergeParams>, DeviceMergeParams, DeviceMergeParams, @@ -93,7 +96,8 @@ using RocprimDeviceMergeTestsParams = ::testing::Types< DeviceMergeParams>, DeviceMergeParams>, DeviceMergeParams, - DeviceMergeParams, true>>; + DeviceMergeParams, true>, + DeviceMergeParams>; // size1, size2 std::vector> get_sizes() @@ -143,6 +147,14 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKey) // hipMallocManaged() currently doesnt support zero byte allocation continue; } + + if((std::get<0>(sizes) + std::get<1>(sizes) >= 100000 + && sizeof(key_type) > sizeof(size_t) * 16)) + { + // Huge types are slow + continue; + } + SCOPED_TRACE( testing::Message() << "with sizes = {" << std::get<0>(sizes) << ", " << std::get<1>(sizes) << "}" @@ -279,6 +291,14 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKeyValue) // hipMallocManaged() currently doesnt support zero byte allocation continue; } + + if((std::get<0>(sizes) + std::get<1>(sizes) >= 100000 + && sizeof(key_type) > sizeof(size_t) * 16)) + { + // Huge types are slow + continue; + } + SCOPED_TRACE( testing::Message() << "with sizes = {" << std::get<0>(sizes) << ", " << std::get<1>(sizes) << "}" @@ -423,7 +443,9 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKeyValue) } if (TestFixture::use_graphs) + { HIP_CHECK(hipStreamDestroy(stream)); + } } template diff --git a/test/rocprim/test_device_nth_element.cpp b/test/rocprim/test_device_nth_element.cpp index 672d49e26..3a5dadf11 100644 --- a/test/rocprim/test_device_nth_element.cpp +++ b/test/rocprim/test_device_nth_element.cpp @@ -90,7 +90,6 @@ void inline compare_cpp_14(InputVector input, ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(sorted_output, sorted_input)); } -#if CPP17 template void inline compare_cpp_17(InputVector input, OutputVector output, @@ -122,7 +121,6 @@ void inline compare_cpp_17(InputVector input, ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(sorted_output, sorted_input)); } -#endif template void inline compare(InputVector input, @@ -131,12 +129,7 @@ void inline compare(InputVector input, CompareFunction compare_op) { compare_cpp_14(input, output, nth_element, compare_op); -#if CPP17 - // this comparison is only compiled and executed if c++17 is available compare_cpp_17(input, output, nth_element, compare_op); -#else - ROCPRIM_PRAGMA_MESSAGE("c++17 not available skips direct comparison with std::nth_element"); -#endif } // --------------------------------------------------------- diff --git a/test/rocprim/test_device_partial_sort.cpp b/test/rocprim/test_device_partial_sort.cpp index 1550513cc..a48d669ca 100644 --- a/test/rocprim/test_device_partial_sort.cpp +++ b/test/rocprim/test_device_partial_sort.cpp @@ -140,7 +140,6 @@ void inline compare_partial_sort_cpp_14(InputVector input, ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(sorted_output, sorted_input)); } -#if CPP17 template void inline compare_partial_sort_cpp_17(InputVector input, OutputVector output, @@ -169,7 +168,6 @@ void inline compare_partial_sort_cpp_17(InputVector input, // Check if the values are the same ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(sorted_output, sorted_input)); } -#endif template void inline compare_partial_sort(InputVector input, @@ -178,12 +176,7 @@ void inline compare_partial_sort(InputVector input, CompareFunction compare_op) { compare_partial_sort_cpp_14(input, output, middle, compare_op); -#if CPP17 - // this comparison is only compiled and executed if c++17 is available compare_partial_sort_cpp_17(input, output, middle, compare_op); -#else - ROCPRIM_PRAGMA_MESSAGE("c++17 not available skips direct comparison with std::partial_sort"); -#endif } TYPED_TEST_SUITE(RocprimDevicePartialSortTests, RocprimDevicePartialSortTestsParams); @@ -329,7 +322,6 @@ void inline compare_partial_sort_copy_cpp_14(InputVector input, ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected_output)); } -#if CPP17 template void inline compare_partial_sort_copy_cpp_17(InputVector input, OutputVector output, @@ -355,7 +347,6 @@ void inline compare_partial_sort_copy_cpp_17(InputVector input, // Check if the values are the same ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(sorted_output, output)); } -#endif template void inline compare_partial_sort_copy(InputVector input, @@ -365,13 +356,7 @@ void inline compare_partial_sort_copy(InputVector input, CompareFunction compare_op) { compare_partial_sort_copy_cpp_14(input, output, orignal_output, middle, compare_op); -#if CPP17 - // this comparison is only compiled and executed if c++17 is available compare_partial_sort_copy_cpp_17(input, output, orignal_output, middle, compare_op); -#else - ROCPRIM_PRAGMA_MESSAGE( - "c++17 not available skips direct comparison with std::partial_sort_copy"); -#endif } TYPED_TEST(RocprimDevicePartialSortTests, PartialSortCopy) diff --git a/test/rocprim/test_device_radix_sort.cpp.in b/test/rocprim/test_device_radix_sort.cpp.in index bc88b778b..fda36f10c 100644 --- a/test/rocprim/test_device_radix_sort.cpp.in +++ b/test/rocprim/test_device_radix_sort.cpp.in @@ -24,7 +24,6 @@ #include "../../common/utils_custom_type.hpp" -#include "test_utils_custom_float_traits_type.hpp" #include "test_utils_custom_float_type.hpp" #include "test_utils_custom_test_types.hpp" @@ -87,7 +86,6 @@ using __custom__uint64_t2 = common::custom_type; INSTANTIATE(params) INSTANTIATE(params) INSTANTIATE(params) - INSTANTIATE(params) INSTANTIATE(params>) // start_bit and end_bit diff --git a/test/rocprim/test_device_radix_sort.hpp b/test/rocprim/test_device_radix_sort.hpp index d75dc9999..23a466285 100644 --- a/test/rocprim/test_device_radix_sort.hpp +++ b/test/rocprim/test_device_radix_sort.hpp @@ -32,7 +32,6 @@ #include #include #include -#include #include #include "../../common/utils_device_ptr.hpp" @@ -41,7 +40,6 @@ #include "test_seed.hpp" #include "test_utils.hpp" #include "test_utils_assertions.hpp" -#include "test_utils_custom_float_traits_type.hpp" #include "test_utils_custom_float_type.hpp" #include "test_utils_custom_test_types.hpp" #include "test_utils_data_generation.hpp" @@ -117,8 +115,7 @@ auto generate_key_input(KeyIter keys_input, size_t size, engine_type& rng_engine // Working around custom_float_test_type, which is both a float and a common::custom_type template constexpr bool is_custom_not_float_test_type - = common::is_custom_type::value && !std::is_same::value - && !std::is_same::value; + = common::is_custom_type::value && !std::is_same::value; template auto invoke_sort_keys(void* d_temporary_storage, diff --git a/test/rocprim/test_device_reduce_by_key.cpp b/test/rocprim/test_device_reduce_by_key.cpp index ebce6659d..6c418c2b6 100644 --- a/test/rocprim/test_device_reduce_by_key.cpp +++ b/test/rocprim/test_device_reduce_by_key.cpp @@ -31,7 +31,6 @@ #include "test_seed.hpp" #include "test_utils.hpp" #include "test_utils_assertions.hpp" -#include "test_utils_custom_float_traits_type.hpp" #include "test_utils_custom_test_types.hpp" #include "test_utils_data_generation.hpp" #include "test_utils_hipgraphs.hpp" @@ -45,7 +44,7 @@ #include #include #include -#include +#include #include #include diff --git a/test/rocprim/test_device_scan.cpp b/test/rocprim/test_device_scan.cpp index efb428b76..fecf1b72c 100644 --- a/test/rocprim/test_device_scan.cpp +++ b/test/rocprim/test_device_scan.cpp @@ -25,7 +25,7 @@ #include "../../common/utils_custom_type.hpp" #include "../../common/utils_device_ptr.hpp" -// required test headers +// Required test headers #include "bounds_checking_iterator.hpp" #include "identity_iterator.hpp" #include "test_utils.hpp" @@ -34,16 +34,19 @@ #include "test_utils_data_generation.hpp" #include "test_utils_hipgraphs.hpp" -// required rocprim headers +// Required rocprim headers #include #include #include #include #include #include +#include +#include #include #include #include +#include #include #include #include @@ -98,7 +101,8 @@ template + bool Deterministic = false, + bool UseInitialValue = false> struct DeviceScanParams { using input_type = InputType; @@ -108,6 +112,7 @@ struct DeviceScanParams using config_helper = ConfigHelper; static constexpr bool use_graphs = UseGraphs; static constexpr bool deterministic = Deterministic; + static constexpr bool use_initial_value = UseInitialValue; }; template @@ -176,8 +181,9 @@ class RocprimDeviceScanTests : public ::testing::Test const bool debug_synchronous = false; static constexpr bool use_identity_iterator = Params::use_identity_iterator; using config_helper = typename Params::config_helper; - bool use_graphs = Params::use_graphs; - static constexpr bool deterministic = Params::deterministic; + bool use_graphs = Params::use_graphs; + static constexpr bool deterministic = Params::deterministic; + static constexpr bool use_initial_value = Params::use_initial_value; }; using RocprimDeviceScanTestsParams = ::testing::Types< @@ -247,25 +253,322 @@ using RocprimDeviceScanTestsParams = ::testing::Types< DeviceScanParams>, DeviceScanParams>, // With graphs - DeviceScanParams, false, default_config_helper, true>>; + DeviceScanParams, false, default_config_helper, true>, + // With initial values + DeviceScanParams, + false, + size_limit_config_helper<524288>, + false, + true, + true>, + DeviceScanParams, + false, + size_limit_config_helper<524288>, + false, + true, + true>, + DeviceScanParams, + false, + size_limit_config_helper<524288>, + false, + true, + true>, + DeviceScanParams, + false, + size_limit_config_helper<524288>, + false, + true, + true>, + DeviceScanParams, + false, + size_limit_config_helper<524288>, + false, + true, + true>, + DeviceScanParams, + false, + size_limit_config_helper<524288>, + false, + true, + true>, + DeviceScanParams, + false, + size_limit_config_helper<524288>, + false, + true, + true>, + DeviceScanParams, + false, + size_limit_config_helper<524288>, + false, + true, + true>>; -// use float for accumulation of bfloat16 and half inputs if operator is plus -template struct accum_type { - static constexpr bool is_low_precision = - std::is_same::value || - std::is_same::value; +// Use float for accumulation of bfloat16 and half inputs if operator is plus +template +struct accum_type +{ + static constexpr bool is_low_precision + = std::is_same::value + || std::is_same::value; static constexpr bool is_plus = test_utils::is_plus_operator::value; using type = typename std::conditional_t; }; TYPED_TEST_SUITE(RocprimDeviceScanTests, RocprimDeviceScanTestsParams); +TYPED_TEST(RocprimDeviceScanTests, LookBackScan) +{ + using T = typename TestFixture::input_type; + using U = typename TestFixture::output_type; + using scan_op_type = typename TestFixture::scan_op_type; + // If scan_op_type is rocprim::plus and input_type is bfloat16 or half, + // use float as device-side accumulator and double as host-side accumulator + using is_plus_op = test_utils::is_plus_operator; + using acc_type = typename accum_type::type; + using scan_state_type = rocprim::detail::lookback_scan_state; + using scan_state_with_sleep_type = rocprim::detail::lookback_scan_state; + + const bool deterministic = TestFixture::deterministic; + const bool use_initial_value = TestFixture::use_initial_value; + + using Config = typename TestFixture::config_helper::template type; + using config = rocprim::detail::wrapped_scan_config; + + hipStream_t stream = hipStreamDefault; + + rocprim::detail::target_arch target_arch; + HIP_CHECK(host_target_arch(stream, target_arch)); + const rocprim::detail::scan_config_params params + = rocprim::detail::dispatch_target_arch(target_arch); + + // For non-associative operations in inclusive scan + // intermediate results use the type of input iterator, then + // as all conversions in the tests are to more precise types, + // intermediate results use the same or more precise acc_type, + // all scan operations use the same acc_type, + // and all output types are the same acc_type, + // therefore the only source of error is precision of operation itself + constexpr float single_op_precision = is_plus_op::value ? test_utils::precision : 0; + + int device_id = test_common_utils::obtain_device_from_ctest(); + SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); + HIP_CHECK(hipSetDevice(device_id)); + + for(size_t seed_index = 0; seed_index < number_of_runs; seed_index++) + { + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); + + for(auto size : test_utils::get_sizes(seed_value)) + { + if(size == 0) + { + continue; + } + const unsigned int block_size = params.kernel_config.block_size; + const unsigned int items_per_thread = params.kernel_config.items_per_thread; + const auto items_per_block = block_size * items_per_thread; + + unsigned int number_of_blocks = (size + items_per_block - 1) / items_per_block; + + if(single_op_precision * size > 0.5) + { + std::cout << "Test is skipped from size " << size + << " on, potential error of summation is more than 0.5 of the result " + "with current or larger size" + << std::endl; + break; + } + hipStream_t stream = hipStreamDefault; + if(TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + + SCOPED_TRACE(testing::Message() << "with size = " << size); + + // Generate data + std::vector input = test_utils::get_random_data_wrapped(size, 1, 10, seed_value); + + common::device_ptr d_input(input); + common::device_ptr d_output(input.size()); + + // Scan function + scan_op_type scan_op; + + // Calculate expected results on host + std::vector expected(input.size()); + acc_type initial_value; + if(use_initial_value) + { + initial_value = test_utils::get_random_value(1, 10, seed_value); + test_utils::host_inclusive_scan(input.begin(), + input.end(), + expected.begin(), + scan_op, + initial_value); + } + else + { + test_utils::host_inclusive_scan(input.begin(), + input.end(), + expected.begin(), + scan_op); + } + SCOPED_TRACE(use_initial_value + ? (testing::Message() << "with initial_value = " << initial_value) + : (testing::Message() << "without initial_value")); + + auto input_iterator + = rocprim::make_transform_iterator(d_input.get(), + [](T in) { return static_cast(in); }); + + // Pointer to array with block_prefixes + acc_type* previous_last_element; + acc_type* new_last_element; + + rocprim::detail::temp_storage::layout layout{}; + HIP_CHECK(scan_state_type::get_temp_storage_layout(number_of_blocks, stream, layout)); + + size_t storage_size; + HIP_CHECK(scan_state_type::get_storage_size(number_of_blocks, stream, storage_size)); + + // temp_storage_size_bytes must be >0 + ASSERT_GT(storage_size, 0); + + // Allocate temporary storage + common::device_ptr d_temp_storage(storage_size); + + scan_state_type scan_state{}; + HIP_CHECK(scan_state_type::create(scan_state, + d_temp_storage.get(), + number_of_blocks, + stream)); + scan_state_with_sleep_type scan_state_with_sleep{}; + HIP_CHECK(scan_state_with_sleep_type::create(scan_state_with_sleep, + d_temp_storage.get(), + number_of_blocks, + stream)); + + // Call the provided function with either scan_state or scan_state_with_sleep based on + // the value of use_sleep + bool use_sleep; + HIP_CHECK(rocprim::detail::is_sleep_scan_state_used(stream, use_sleep)) + auto with_scan_state = [use_sleep, scan_state, scan_state_with_sleep]( + auto&& func) mutable -> decltype(auto) + { + if(use_sleep) + { + return func(scan_state_with_sleep); + } + else + { + return func(scan_state); + } + }; + auto grid_size = (number_of_blocks + block_size - 1) / block_size; + with_scan_state( + [&](const auto scan_state) + { + rocprim::detail::init_lookback_scan_state_kernel<<>>(scan_state, + number_of_blocks); + }); + + HIP_CHECK(hipGetLastError()); + HIP_CHECK(hipDeviceSynchronize()); + + test_utils::GraphHelper gHelper; + if(TestFixture::use_graphs) + { + gHelper.startStreamCapture(stream); + } + + static constexpr bool Exclusive = false; + + grid_size = number_of_blocks; + + with_scan_state( + [&](const auto scan_state) + { + rocprim::detail::lookback_scan_kernel< + deterministic + ? rocprim::detail::lookback_scan_determinism::deterministic + : rocprim::detail::lookback_scan_determinism::nondeterministic, + Exclusive, + use_initial_value, + config, + decltype(input_iterator), + U*, + scan_op_type, + acc_type, + acc_type> + <<>>(input_iterator, + d_output.get(), + size, + initial_value, + scan_op, + scan_state, + number_of_blocks, + previous_last_element, + new_last_element, + false, + false); + }); + + if(TestFixture::use_graphs) + { + gHelper.createAndLaunchGraph(stream, true, false); + } + + HIP_CHECK(hipGetLastError()); + HIP_CHECK(hipDeviceSynchronize()); + + // Copy output to host + const auto output = d_output.load(); + + // Check if output values are as expected + ASSERT_NO_FATAL_FAILURE( + test_utils::assert_near(output, expected, single_op_precision * size)); + + if(TestFixture::use_graphs) + { + gHelper.cleanupGraphHelper(); + HIP_CHECK(hipStreamDestroy(stream)); + } + } + } +} + +TYPED_TEST_SUITE(RocprimDeviceScanTests, RocprimDeviceScanTestsParams); + TYPED_TEST(RocprimDeviceScanTests, InclusiveScanEmptyInput) { using T = typename TestFixture::input_type; using U = typename TestFixture::output_type; using scan_op_type = typename TestFixture::scan_op_type; - // if scan_op_type is rocprim::plus and input_type is bfloat16 or half, + // If scan_op_type is rocprim::plus and input_type is bfloat16 or half, // use float as device-side accumulator and double as host-side accumulator using acc_type = typename accum_type::type; const bool debug_synchronous = TestFixture::debug_synchronous; @@ -274,9 +577,10 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanEmptyInput) int device_id = test_common_utils::obtain_device_from_ctest(); SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); HIP_CHECK(hipSetDevice(device_id)); - - hipStream_t stream = 0; // default - if (TestFixture::use_graphs) + + // Default + hipStream_t stream = 0; + if(TestFixture::use_graphs) { // Default stream does not support hipGraph stream capture, so create one HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); @@ -284,12 +588,12 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanEmptyInput) common::device_ptr d_output(1); - test_utils::out_of_bounds_flag out_of_bounds; + test_utils::out_of_bounds_flag out_of_bounds; test_utils::bounds_checking_iterator d_checking_output(d_output.get(), out_of_bounds.device_pointer(), 0); - // scan function + // Scan function scan_op_type scan_op; auto input_iterator @@ -307,7 +611,7 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanEmptyInput) stream, debug_synchronous)); - // allocate temporary storage + // Allocate temporary storage common::device_ptr d_temp_storage(temp_storage_size_bytes); test_utils::GraphHelper gHelper; @@ -336,7 +640,7 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanEmptyInput) ASSERT_FALSE(out_of_bounds.get()); - if (TestFixture::use_graphs) + if(TestFixture::use_graphs) { gHelper.cleanupGraphHelper(); HIP_CHECK(hipStreamDestroy(stream)); @@ -349,12 +653,12 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScan) using U = typename TestFixture::output_type; using scan_op_type = typename TestFixture::scan_op_type; - // if scan_op_type is rocprim::plus and input_type is bfloat16 or half, + // If scan_op_type is rocprim::plus and input_type is bfloat16 or half, // use float as device-side accumulator and double as host-side accumulator using is_plus_op = test_utils::is_plus_operator; using acc_type = typename accum_type::type; - // for non-associative operations in inclusive scan + // For non-associative operations in inclusive scan // intermediate results use the type of input iterator, then // as all conversions in the tests are to more precise types, // intermediate results use the same or more precise acc_type, @@ -373,7 +677,8 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScan) for(size_t seed_index = 0; seed_index < number_of_runs; seed_index++) { - unsigned int seed_value = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); for(auto size : test_utils::get_sizes(seed_value)) @@ -386,8 +691,10 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScan) << std::endl; break; } - hipStream_t stream = 0; // default - if (TestFixture::use_graphs) + + // Default + hipStream_t stream = 0; + if(TestFixture::use_graphs) { // Default stream does not support hipGraph stream capture, so create one HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); @@ -401,15 +708,31 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScan) common::device_ptr d_input(input); common::device_ptr d_output(input.size()); - // scan function + // Scan function scan_op_type scan_op; // Calculate expected results on host std::vector expected(input.size()); - test_utils::host_inclusive_scan( - input.begin(), input.end(), - expected.begin(), scan_op - ); + acc_type initial_value; + if(TestFixture::use_initial_value) + { + initial_value = test_utils::get_random_value(1, 10, seed_value); + test_utils::host_inclusive_scan(input.begin(), + input.end(), + expected.begin(), + scan_op, + initial_value); + } + else + { + test_utils::host_inclusive_scan(input.begin(), + input.end(), + expected.begin(), + scan_op); + } + SCOPED_TRACE(TestFixture::use_initial_value + ? (testing::Message() << "with initial_value = " << initial_value) + : (testing::Message() << "without initial_value")); auto input_iterator = rocprim::make_transform_iterator(d_input.get(), @@ -417,20 +740,36 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScan) // Get size of d_temp_storage size_t temp_storage_size_bytes; - HIP_CHECK((invoke_inclusive_scan( - nullptr, - temp_storage_size_bytes, - input_iterator, - test_utils::wrap_in_identity_iterator(d_output.get()), - input.size(), - scan_op, - stream, - TestFixture::debug_synchronous))); + if(TestFixture::use_initial_value) + { + HIP_CHECK((invoke_inclusive_scan( + nullptr, + temp_storage_size_bytes, + input_iterator, + test_utils::wrap_in_identity_iterator(d_output.get()), + initial_value, + input.size(), + scan_op, + stream, + TestFixture::debug_synchronous))); + } + else + { + HIP_CHECK((invoke_inclusive_scan( + nullptr, + temp_storage_size_bytes, + input_iterator, + test_utils::wrap_in_identity_iterator(d_output.get()), + input.size(), + scan_op, + stream, + TestFixture::debug_synchronous))); + } // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); - // allocate temporary storage + // Allocate temporary storage common::device_ptr d_temp_storage(temp_storage_size_bytes); test_utils::GraphHelper gHelper; @@ -440,15 +779,31 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScan) } // Run - HIP_CHECK((invoke_inclusive_scan( - d_temp_storage.get(), - temp_storage_size_bytes, - input_iterator, - test_utils::wrap_in_identity_iterator(d_output.get()), - input.size(), - scan_op, - stream, - TestFixture::debug_synchronous))); + if(TestFixture::use_initial_value) + { + HIP_CHECK((invoke_inclusive_scan( + d_temp_storage.get(), + temp_storage_size_bytes, + input_iterator, + test_utils::wrap_in_identity_iterator(d_output.get()), + initial_value, + input.size(), + scan_op, + stream, + TestFixture::debug_synchronous))); + } + else + { + HIP_CHECK((invoke_inclusive_scan( + d_temp_storage.get(), + temp_storage_size_bytes, + input_iterator, + test_utils::wrap_in_identity_iterator(d_output.get()), + input.size(), + scan_op, + stream, + TestFixture::debug_synchronous))); + } if(TestFixture::use_graphs) { @@ -471,7 +826,7 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScan) } } - if (TestFixture::use_graphs) + if(TestFixture::use_graphs) { gHelper.cleanupGraphHelper(); HIP_CHECK(hipStreamDestroy(stream)); @@ -486,12 +841,12 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScan) using U = typename TestFixture::output_type; using scan_op_type = typename TestFixture::scan_op_type; - // if scan_op_type is rocprim::plus and input_type is bfloat16 or half, + // If scan_op_type is rocprim::plus and input_type is bfloat16 or half, // use float as device-side accumulator and double as host-side accumulator using is_plus_op = test_utils::is_plus_operator; using acc_type = typename accum_type::type; - // for non-associative operations in exclusive scan + // For non-associative operations in exclusive scan // intermediate results use the type of initial value, then // as all conversions in the tests are to more precise types, // intermediate results use the same or more precise acc_type, @@ -512,7 +867,8 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScan) for(size_t seed_index = 0; seed_index < number_of_runs; seed_index++) { - unsigned int seed_value = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); for(auto size : test_utils::get_sizes(seed_value)) @@ -525,8 +881,10 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScan) << std::endl; break; } - hipStream_t stream = 0; // default - if (TestFixture::use_graphs) + + // Default + hipStream_t stream = 0; + if(TestFixture::use_graphs) { // Default stream does not support hipGraph stream capture, so create one HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); @@ -540,17 +898,17 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScan) common::device_ptr d_input(input); common::device_ptr d_output(input.size()); - // scan function + // Scan function scan_op_type scan_op; // Calculate expected results on host std::vector expected(input.size()); initial_value = test_utils::get_random_value(1, 10, seed_value); - test_utils::host_exclusive_scan( - input.begin(), input.end(), - initial_value, expected.begin(), - scan_op - ); + test_utils::host_exclusive_scan(input.begin(), + input.end(), + initial_value, + expected.begin(), + scan_op); auto input_iterator = rocprim::make_transform_iterator(d_input.get(), @@ -572,7 +930,7 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScan) // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); - // allocate temporary storage + // Allocate temporary storage common::device_ptr d_temp_storage(temp_storage_size_bytes); test_utils::GraphHelper gHelper; @@ -614,7 +972,7 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScan) } } - if (TestFixture::use_graphs) + if(TestFixture::use_graphs) { gHelper.cleanupGraphHelper(); HIP_CHECK(hipStreamDestroy(stream)); @@ -627,11 +985,11 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanByKey) { // scan-by-key does not support output iterator with void value_type using T = typename TestFixture::input_type; - using K = unsigned int; // key type + using K = unsigned int; // Key type using U = typename TestFixture::output_type; using scan_op_type = typename TestFixture::scan_op_type; - // if scan_op_type is rocprim::plus and input_type is bfloat16 or half, + // If scan_op_type is rocprim::plus and input_type is bfloat16 or half, // use float as device-side accumulator and double as host-side accumulator using is_plus_op = test_utils::is_plus_operator; using acc_type = typename accum_type::type; @@ -648,7 +1006,8 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanByKey) for(size_t seed_index = 0; seed_index < number_of_runs; seed_index++) { - unsigned int seed_value = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); for(auto size : test_utils::get_sizes(seed_value)) @@ -661,8 +1020,10 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanByKey) << std::endl; break; } - hipStream_t stream = 0; // default - if (TestFixture::use_graphs) + + // Default + hipStream_t stream = 0; + if(TestFixture::use_graphs) { // Default stream does not support hipGraph stream capture, so create one HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); @@ -689,18 +1050,19 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanByKey) common::device_ptr d_keys(keys); common::device_ptr d_output(input.size()); - // scan function + // Scan function scan_op_type scan_op; - // key compare function + // Key compare function rocprim::equal_to keys_compare_op; // Calculate expected results on host std::vector expected(input.size()); - test_utils::host_inclusive_scan_by_key( - input.begin(), input.end(), keys.begin(), - expected.begin(), - scan_op, keys_compare_op - ); + test_utils::host_inclusive_scan_by_key(input.begin(), + input.end(), + keys.begin(), + expected.begin(), + scan_op, + keys_compare_op); auto input_iterator = rocprim::make_transform_iterator(d_input.get(), @@ -722,7 +1084,7 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanByKey) // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); - // allocate temporary storage + // Allocate temporary storage common::device_ptr d_temp_storage(temp_storage_size_bytes); test_utils::GraphHelper gHelper; @@ -758,7 +1120,7 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanByKey) ASSERT_NO_FATAL_FAILURE( test_utils::assert_near(output, expected, single_op_precision * (size - 1))); - if (TestFixture::use_graphs) + if(TestFixture::use_graphs) { gHelper.cleanupGraphHelper(); HIP_CHECK(hipStreamDestroy(stream)); @@ -771,11 +1133,11 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScanByKey) { // scan-by-key does not support output iterator with void value_type using T = typename TestFixture::input_type; - using K = unsigned int; // key type + using K = unsigned int; // Key type using U = typename TestFixture::output_type; using scan_op_type = typename TestFixture::scan_op_type; - // if scan_op_type is rocprim::plus and input_type is bfloat16 or half, + // If scan_op_type is rocprim::plus and input_type is bfloat16 or half, // use float as device-side accumulator and double as host-side accumulator using is_plus_op = test_utils::is_plus_operator; using acc_type = typename accum_type::type; @@ -793,7 +1155,8 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScanByKey) for(size_t seed_index = 0; seed_index < number_of_runs; seed_index++) { - unsigned int seed_value = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); for(auto size : test_utils::get_sizes(seed_value)) @@ -806,8 +1169,10 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScanByKey) << std::endl; break; } - hipStream_t stream = 0; // default - if (TestFixture::use_graphs) + + // Default + hipStream_t stream = 0; + if(TestFixture::use_graphs) { // Default stream does not support hipGraph stream capture, so create one HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); @@ -835,19 +1200,21 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScanByKey) common::device_ptr d_keys(keys); common::device_ptr d_output(input.size()); - // scan function + // Scan function scan_op_type scan_op; - // key compare function + // Key compare function rocprim::equal_to keys_compare_op; // Calculate expected results on host std::vector expected(input.size()); - test_utils::host_exclusive_scan_by_key( - input.begin(), input.end(), keys.begin(), - initial_value, expected.begin(), - scan_op, keys_compare_op - ); + test_utils::host_exclusive_scan_by_key(input.begin(), + input.end(), + keys.begin(), + initial_value, + expected.begin(), + scan_op, + keys_compare_op); auto input_iterator = rocprim::make_transform_iterator(d_input.get(), @@ -870,7 +1237,7 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScanByKey) // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); - // allocate temporary storage + // Allocate temporary storage common::device_ptr d_temp_storage(temp_storage_size_bytes); test_utils::GraphHelper gHelper; @@ -907,7 +1274,7 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScanByKey) ASSERT_NO_FATAL_FAILURE( test_utils::assert_near(output, expected, single_op_precision * (size - 1))); - if (TestFixture::use_graphs) + if(TestFixture::use_graphs) { gHelper.cleanupGraphHelper(); HIP_CHECK(hipStreamDestroy(stream)); @@ -916,23 +1283,28 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScanByKey) } } -template -class single_index_iterator { +template +class single_index_iterator +{ private: - class conditional_discard_value { + class conditional_discard_value + { public: __host__ __device__ explicit conditional_discard_value(T* const value, bool keep) - : value_{value} - , keep_{keep} - { - } + : value_{value}, keep_{keep} + {} - __host__ __device__ conditional_discard_value& operator=(T value) { - if(keep_) { + __host__ __device__ + conditional_discard_value& + operator=(T value) + { + if(keep_) + { *value_ = value; } return *this; } + private: T* const value_; const bool keep_; @@ -947,17 +1319,17 @@ class single_index_iterator { using reference = conditional_discard_value; using pointer = conditional_discard_value*; using iterator_category = std::random_access_iterator_tag; - using difference_type = std::ptrdiff_t; + using difference_type = std::ptrdiff_t; __host__ __device__ single_index_iterator(T* value, size_t expected_index, size_t index = 0) - : value_{value} - , expected_index_{expected_index} - , index_{index} - { - } + : value_{value}, expected_index_{expected_index}, index_{index} + {} __host__ __device__ single_index_iterator(const single_index_iterator&) = default; - __host__ __device__ single_index_iterator& operator=(const single_index_iterator&) = default; + __host__ __device__ + single_index_iterator& + operator=(const single_index_iterator&) + = default; // clang-format off __host__ __device__ bool operator==(const single_index_iterator& rhs) const { return index_ == rhs.index_; } @@ -983,20 +1355,21 @@ class single_index_iterator { // clang-format on }; -template +template void testLargeIndicesInclusiveScan() { int device_id = test_common_utils::obtain_device_from_ctest(); SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); HIP_CHECK(hipSetDevice(device_id)); - using T = size_t; - using Iterator = typename rocprim::counting_iterator; - using OutputIterator = single_index_iterator; + using T = size_t; + using Iterator = typename rocprim::counting_iterator; + using OutputIterator = single_index_iterator; const bool debug_synchronous = false; - hipStream_t stream = 0; // default - if (UseGraphs) + // Default + hipStream_t stream = 0; + if(UseGraphs) { // Default stream does not support hipGraph stream capture, so create one HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); @@ -1022,20 +1395,37 @@ void testLargeIndicesInclusiveScan() OutputIterator output_it{d_output.get(), size - 1}; // Get temporary array size + size_t initial_value = 0; size_t temp_storage_size_bytes; - HIP_CHECK(rocprim::inclusive_scan(nullptr, - temp_storage_size_bytes, - input_begin, - output_it, - size, - ::rocprim::plus(), - stream, - debug_synchronous)); + if constexpr(UseInitialValue) + { + initial_value = test_utils::get_random_value(0, 10000, seed_value); + HIP_CHECK(rocprim::inclusive_scan(nullptr, + temp_storage_size_bytes, + input_begin, + output_it, + initial_value, + size, + ::rocprim::plus(), + stream, + debug_synchronous)); + } + else + { + HIP_CHECK(rocprim::inclusive_scan(nullptr, + temp_storage_size_bytes, + input_begin, + output_it, + size, + ::rocprim::plus(), + stream, + debug_synchronous)); + } // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); - // allocate temporary storage + // Allocate temporary storage common::device_ptr d_temp_storage(temp_storage_size_bytes); test_utils::GraphHelper gHelper; @@ -1045,14 +1435,32 @@ void testLargeIndicesInclusiveScan() } // Run - HIP_CHECK(rocprim::inclusive_scan(d_temp_storage.get(), - temp_storage_size_bytes, - input_begin, - output_it, - size, - ::rocprim::plus(), - stream, - debug_synchronous)); + if constexpr(UseInitialValue) + { + HIP_CHECK(rocprim::inclusive_scan(d_temp_storage.get(), + temp_storage_size_bytes, + input_begin, + output_it, + initial_value, + size, + ::rocprim::plus(), + stream, + debug_synchronous)); + } + else + { + HIP_CHECK(rocprim::inclusive_scan(d_temp_storage.get(), + temp_storage_size_bytes, + input_begin, + output_it, + size, + ::rocprim::plus(), + stream, + debug_synchronous)); + } + SCOPED_TRACE(UseInitialValue + ? (testing::Message() << "with initial_value = " << initial_value) + : (testing::Message() << "without initial_value")); if(UseGraphs) { @@ -1069,8 +1477,10 @@ void testLargeIndicesInclusiveScan() // The division is not integer division but either (size) or (2n + size - 1) has to be even. const T multiplicand_1 = size; const T multiplicand_2 = 2 * (*input_begin) + size - 1; - const T expected_output = (multiplicand_1 % 2 == 0) ? multiplicand_1 / 2 * multiplicand_2 - : multiplicand_1 * (multiplicand_2 / 2); + const T expected_output + = ((multiplicand_1 % 2 == 0) ? multiplicand_1 / 2 * multiplicand_2 + : multiplicand_1 * (multiplicand_2 / 2)) + + initial_value; ASSERT_EQ(output, expected_output); @@ -1097,6 +1507,16 @@ TEST(RocprimDeviceScanTests, LargeIndicesInclusiveScanWithGraphs) testLargeIndicesInclusiveScan(); } +TEST(RocprimDeviceScanTests, LargeIndicesInclusiveScanWithInitialValue) +{ + testLargeIndicesInclusiveScan(); +} + +TEST(RocprimDeviceScanTests, LargeIndicesInclusiveScanWithInitialValueAndGraphs) +{ + testLargeIndicesInclusiveScan(); +} + template void testLargeIndicesExclusiveScan() { @@ -1104,13 +1524,14 @@ void testLargeIndicesExclusiveScan() SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); HIP_CHECK(hipSetDevice(device_id)); - using T = size_t; - using Iterator = typename rocprim::counting_iterator; - using OutputIterator = single_index_iterator; + using T = size_t; + using Iterator = typename rocprim::counting_iterator; + using OutputIterator = single_index_iterator; const bool debug_synchronous = false; - hipStream_t stream = 0; // default - if (UseGraphs) + // Default + hipStream_t stream = 0; + if(UseGraphs) { // Default stream does not support hipGraph stream capture, so create one HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); @@ -1152,7 +1573,7 @@ void testLargeIndicesExclusiveScan() // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); - // allocate temporary storage + // Allocate temporary storage common::device_ptr d_temp_storage(temp_storage_size_bytes); test_utils::GraphHelper gHelper; @@ -1189,7 +1610,7 @@ void testLargeIndicesExclusiveScan() const T multiplicand_2 = 2 * (*input_begin) + size - 2; const T product = (multiplicand_1 % 2 == 0) ? multiplicand_1 / 2 * multiplicand_2 - : multiplicand_1 * (multiplicand_2 / 2); + : multiplicand_1 * (multiplicand_2 / 2); const T expected_output = initial_value + product; @@ -1230,64 +1651,87 @@ class check_run_iterator using reference = CheckValue; using pointer = CheckValue*; using iterator_category = std::random_access_iterator_tag; - using difference_type = std::ptrdiff_t; + using difference_type = std::ptrdiff_t; - ROCPRIM_HOST_DEVICE - check_run_iterator(const args_t args) : current_index_(0), args_(args) {} + ROCPRIM_HOST_DEVICE check_run_iterator(const args_t args) : current_index_(0), args_(args) {} - ROCPRIM_HOST_DEVICE bool operator==(const check_run_iterator& rhs) const + ROCPRIM_HOST_DEVICE + bool operator==(const check_run_iterator& rhs) const { return current_index_ == rhs.current_index_; } - ROCPRIM_HOST_DEVICE bool operator!=(const check_run_iterator& rhs) const + ROCPRIM_HOST_DEVICE + bool operator!=(const check_run_iterator& rhs) const { return !(*this == rhs); } - ROCPRIM_HOST_DEVICE reference operator*() + ROCPRIM_HOST_DEVICE + reference + operator*() { return value_type{current_index_, args_}; } - ROCPRIM_HOST_DEVICE reference operator[](const difference_type distance) const + ROCPRIM_HOST_DEVICE + reference + operator[](const difference_type distance) const { return *(*this + distance); } - ROCPRIM_HOST_DEVICE check_run_iterator& operator+=(const difference_type rhs) + ROCPRIM_HOST_DEVICE + check_run_iterator& + operator+=(const difference_type rhs) { current_index_ += rhs; return *this; } - ROCPRIM_HOST_DEVICE check_run_iterator& operator-=(const difference_type rhs) + ROCPRIM_HOST_DEVICE + check_run_iterator& + operator-=(const difference_type rhs) { current_index_ -= rhs; return *this; } - ROCPRIM_HOST_DEVICE difference_type operator-(const check_run_iterator& rhs) const + ROCPRIM_HOST_DEVICE + difference_type + operator-(const check_run_iterator& rhs) const { return current_index_ - rhs.current_index_; } - ROCPRIM_HOST_DEVICE check_run_iterator operator+(const difference_type rhs) const + ROCPRIM_HOST_DEVICE + check_run_iterator + operator+(const difference_type rhs) const { return check_run_iterator(*this) += rhs; } - ROCPRIM_HOST_DEVICE check_run_iterator operator-(const difference_type rhs) const + ROCPRIM_HOST_DEVICE + check_run_iterator + operator-(const difference_type rhs) const { return check_run_iterator(*this) -= rhs; } - ROCPRIM_HOST_DEVICE check_run_iterator& operator++() + ROCPRIM_HOST_DEVICE + check_run_iterator& + operator++() { ++current_index_; return *this; } - ROCPRIM_HOST_DEVICE check_run_iterator& operator--() + ROCPRIM_HOST_DEVICE + check_run_iterator& + operator--() { --current_index_; return *this; } - ROCPRIM_HOST_DEVICE check_run_iterator operator++(int) + ROCPRIM_HOST_DEVICE + check_run_iterator + operator++(int) { return ++check_run_iterator{*this}; } - ROCPRIM_HOST_DEVICE check_run_iterator operator--(int) + ROCPRIM_HOST_DEVICE + check_run_iterator + operator--(int) { return --check_run_iterator{*this}; } @@ -1307,7 +1751,8 @@ struct check_value_inclusive rocprim::tuple args_; // run_length, incorrect flag ROCPRIM_HOST_DEVICE - size_t operator=(const size_t value) + size_t + operator=(const size_t value) { const size_t run_start = current_index_ - (current_index_ % rocprim::get<0>(args_)); const size_t index_in_run = current_index_ - run_start + 1; @@ -1331,7 +1776,8 @@ struct check_value_exclusive args_; // run_length, initial_value, incorrect flag ROCPRIM_HOST_DEVICE - size_t operator=(const size_t value) + size_t + operator=(const size_t value) { const size_t run_start = current_index_ - (current_index_ % rocprim::get<0>(args_)); const size_t index_in_run = current_index_ - run_start; @@ -1359,9 +1805,9 @@ void large_indices_scan_by_key_test(ScanByKeyFun scan_by_key_fun) SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); HIP_CHECK(hipSetDevice(device_id)); - constexpr bool debug_synchronous = false; - hipStream_t stream = 0; - if (UseGraphs) + constexpr bool debug_synchronous = false; + hipStream_t stream = 0; + if(UseGraphs) { // Default stream does not support hipGraph stream capture, so create one HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); @@ -1428,7 +1874,7 @@ void large_indices_scan_by_key_test(ScanByKeyFun scan_by_key_fun) ASSERT_EQ(0, incorrect_flag); - if (UseGraphs) + if(UseGraphs) { gHelper.cleanupGraphHelper(); HIP_CHECK(hipStreamDestroy(stream)); @@ -1450,7 +1896,7 @@ void testLargeIndicesInclusiveScanByKey() int /*seed_value*/) -> hipError_t { const check_run_inclusive_iterator output_it( - rocprim::make_tuple(run_length, d_incorrect_flag)); + rocprim::make_tuple(run_length, d_incorrect_flag)); return rocprim::inclusive_scan_by_key(d_temp_storage, temp_storage_size_bytes, @@ -1463,7 +1909,8 @@ void testLargeIndicesInclusiveScanByKey() stream, debug_synchronous); }; - large_indices_scan_by_key_test(inclusive_scan_by_key); + large_indices_scan_by_key_test( + inclusive_scan_by_key); } TEST(RocprimDeviceScanTests, LargeIndicesInclusiveScanByKey) @@ -1492,7 +1939,7 @@ void testLargeIndicesExclusiveScanByKey() { const size_t initial_value = test_utils::get_random_value(0, 10000, seed_value); const check_run_exclusive_iterator output_it( - rocprim::make_tuple(run_length, initial_value, d_incorrect_flag)); + rocprim::make_tuple(run_length, initial_value, d_incorrect_flag)); return rocprim::exclusive_scan_by_key(d_temp_storage, temp_storage_size_bytes, keys_input, @@ -1505,7 +1952,8 @@ void testLargeIndicesExclusiveScanByKey() stream, debug_synchronous); }; - large_indices_scan_by_key_test(exclusive_scan_by_key); + large_indices_scan_by_key_test( + exclusive_scan_by_key); } TEST(RocprimDeviceScanTests, LargeIndicesExclusiveScanByKey) @@ -1527,10 +1975,9 @@ using RocprimDeviceScanFutureTestsParams = ::testing::Types< DeviceScanParams>, DeviceScanParams, false, default_config_helper, true>>; -template +template class RocprimDeviceScanFutureTests : public RocprimDeviceScanTests -{ -}; +{}; TYPED_TEST_SUITE(RocprimDeviceScanFutureTests, RocprimDeviceScanFutureTestsParams); @@ -1575,8 +2022,9 @@ TYPED_TEST(RocprimDeviceScanFutureTests, ExclusiveScan) break; } - hipStream_t stream = 0; // default - if (TestFixture::use_graphs) + // Default + hipStream_t stream = 0; + if(TestFixture::use_graphs) { // Default stream does not support hipGraph stream capture, so create one HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); @@ -1594,15 +2042,19 @@ TYPED_TEST(RocprimDeviceScanFutureTests, ExclusiveScan) common::device_ptr d_future_input(future_input); common::device_ptr d_initial_value(1); - // scan function + // Scan function scan_op_type scan_op; - const acc_type initial_value = std::accumulate(future_input.begin(), future_input.end(), T(0)); + const acc_type initial_value + = std::accumulate(future_input.begin(), future_input.end(), T(0)); // Calculate expected results on host std::vector expected(input.size()); - test_utils::host_exclusive_scan( - input.begin(), input.end(), initial_value, expected.begin(), scan_op); + test_utils::host_exclusive_scan(input.begin(), + input.end(), + initial_value, + expected.begin(), + scan_op); const auto future_iter = test_utils::wrap_in_identity_iterator( d_initial_value.get()); @@ -1639,7 +2091,7 @@ TYPED_TEST(RocprimDeviceScanFutureTests, ExclusiveScan) rocprim::plus(), stream)); - // allocate temporary storage + // Allocate temporary storage, // we use a char pointer as we need to offset it common::device_ptr d_temp_storage(temp_storage_size_bytes + temp_storage_reduce); @@ -1682,7 +2134,7 @@ TYPED_TEST(RocprimDeviceScanFutureTests, ExclusiveScan) // Check if output values are as expected ASSERT_NO_FATAL_FAILURE(test_utils::assert_near(output, expected, precision)); - if (TestFixture::use_graphs) + if(TestFixture::use_graphs) { gHelper.cleanupGraphHelper(); HIP_CHECK(hipStreamDestroy(stream)); diff --git a/test/rocprim/test_device_search.cpp b/test/rocprim/test_device_search.cpp index 7f2e84952..979befae7 100644 --- a/test/rocprim/test_device_search.cpp +++ b/test/rocprim/test_device_search.cpp @@ -39,7 +39,6 @@ #include #include #include -#include #include #include diff --git a/test/rocprim/test_device_search_n.cpp b/test/rocprim/test_device_search_n.cpp index 04da41d7c..8f8b713cf 100644 --- a/test/rocprim/test_device_search_n.cpp +++ b/test/rocprim/test_device_search_n.cpp @@ -81,7 +81,7 @@ using custom_double2 = common::custom_type; using custom_int64_array = test_utils::custom_test_array_type; // Custom configs -using custom_config_0 = rocprim::search_n_config<256, 4>; +using custom_config_0 = rocprim::search_n_config<256, 4, 6>; using RocprimDeviceSearchNTestsParams = ::testing::Types< // Tests with default configuration @@ -151,16 +151,16 @@ TYPED_TEST(RocprimDeviceSearchNTests, RandomTest) std::fill(h_input.begin() + index, h_input.begin() + index + count, h_value); } - output_type h_output; - common::device_ptr d_input(h_input); - common::device_ptr d_value(std::vector({h_value})); - common::device_ptr d_output(1); + output_type h_output; + common::device_ptr d_input(h_input); + common::device_ptr d_value(std::vector({h_value})); + common::device_ptr d_output(1); SCOPED_TRACE(testing::Message() << "with size = " << h_input.size()); SCOPED_TRACE(testing::Message() << "with count = " << count); SCOPED_TRACE(testing::Message() << "with value = " << h_value); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { // Default stream does not support hipGraph stream capture, so create one HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); @@ -192,7 +192,7 @@ TYPED_TEST(RocprimDeviceSearchNTests, RandomTest) stream, debug_synchronous)); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { gHelper.createAndLaunchGraph(stream); } @@ -206,7 +206,7 @@ TYPED_TEST(RocprimDeviceSearchNTests, RandomTest) h_output = d_output.load()[0]; ASSERT_EQ(h_output, expected); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { gHelper.cleanupGraphHelper(); HIP_CHECK(hipStreamDestroy(stream)); @@ -243,17 +243,17 @@ TYPED_TEST(RocprimDeviceSearchNTests, MaxCount) = test_utils::get_random_value(0, limit_type::max(), ++seed_value); - std::vector h_input(size, h_value); - output_type h_output; - common::device_ptr d_input(h_input); - common::device_ptr d_value(std::vector({h_value})); - common::device_ptr d_output(1); + std::vector h_input(size, h_value); + output_type h_output; + common::device_ptr d_input(h_input); + common::device_ptr d_value(std::vector({h_value})); + common::device_ptr d_output(1); SCOPED_TRACE(testing::Message() << "with size = " << h_input.size()); SCOPED_TRACE(testing::Message() << "with count = " << count); SCOPED_TRACE(testing::Message() << "with value = " << h_value); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { // Default stream does not support hipGraph stream capture, so create one HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); @@ -285,7 +285,7 @@ TYPED_TEST(RocprimDeviceSearchNTests, MaxCount) stream, debug_synchronous)); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { gHelper.createAndLaunchGraph(stream); } @@ -301,7 +301,7 @@ TYPED_TEST(RocprimDeviceSearchNTests, MaxCount) ASSERT_EQ(h_output, expected); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { gHelper.cleanupGraphHelper(); HIP_CHECK(hipStreamDestroy(stream)); @@ -338,17 +338,17 @@ TYPED_TEST(RocprimDeviceSearchNTests, MinCount) = test_utils::get_random_value(0, limit_type::max(), ++seed_value); - std::vector h_input(size, h_value); - output_type h_output; - common::device_ptr d_input(h_input); - common::device_ptr d_value(std::vector({h_value})); - common::device_ptr d_output(1); + std::vector h_input(size, h_value); + output_type h_output; + common::device_ptr d_input(h_input); + common::device_ptr d_value(std::vector({h_value})); + common::device_ptr d_output(1); SCOPED_TRACE(testing::Message() << "with size = " << h_input.size()); SCOPED_TRACE(testing::Message() << "with count = " << count); SCOPED_TRACE(testing::Message() << "with value = " << h_value); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { // Default stream does not support hipGraph stream capture, so create one HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); @@ -362,8 +362,8 @@ TYPED_TEST(RocprimDeviceSearchNTests, MinCount) count, nullptr)); - common::device_ptr d_temp_storage(temp_storage_size); - test_utils::GraphHelper gHelper; + common::device_ptr d_temp_storage(temp_storage_size); + test_utils::GraphHelper gHelper; if(TestFixture::use_graphs) { gHelper.startStreamCapture(stream); @@ -379,7 +379,7 @@ TYPED_TEST(RocprimDeviceSearchNTests, MinCount) stream, debug_synchronous)); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { gHelper.createAndLaunchGraph(stream); } @@ -395,7 +395,7 @@ TYPED_TEST(RocprimDeviceSearchNTests, MinCount) ASSERT_EQ(h_output, expected); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { gHelper.cleanupGraphHelper(); HIP_CHECK(hipStreamDestroy(stream)); @@ -425,16 +425,16 @@ TYPED_TEST(RocprimDeviceSearchNTests, SmallCount) for(const auto size : test_utils::get_sizes(seed_value)) { - hipStream_t stream = 0; // default - size_t count = 0; - size_t temp_storage_size; - input_type h_value{1}; - input_type h_noise{0}; - std::vector h_input(size, h_noise); - output_type h_output; - common::device_ptr d_input(h_input); - common::device_ptr d_value(std::vector({h_value})); - common::device_ptr d_output(1); + hipStream_t stream = 0; // default + size_t count = 0; + size_t temp_storage_size; + input_type h_value{1}; + input_type h_noise{0}; + std::vector h_input(size, h_noise); + output_type h_output; + common::device_ptr d_input(h_input); + common::device_ptr d_value(std::vector({h_value})); + common::device_ptr d_output(1); if(size > 0 && size - 1 > 0) { @@ -447,7 +447,7 @@ TYPED_TEST(RocprimDeviceSearchNTests, SmallCount) SCOPED_TRACE(testing::Message() << "with count = " << count); SCOPED_TRACE(testing::Message() << "with value = " << h_value); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { // Default stream does not support hipGraph stream capture, so create one HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); @@ -461,8 +461,8 @@ TYPED_TEST(RocprimDeviceSearchNTests, SmallCount) count, nullptr)); - common::device_ptr d_temp_storage(temp_storage_size); - test_utils::GraphHelper gHelper; + common::device_ptr d_temp_storage(temp_storage_size); + test_utils::GraphHelper gHelper; if(TestFixture::use_graphs) { gHelper.startStreamCapture(stream); @@ -478,7 +478,7 @@ TYPED_TEST(RocprimDeviceSearchNTests, SmallCount) stream, debug_synchronous)); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { gHelper.createAndLaunchGraph(stream); } @@ -494,7 +494,7 @@ TYPED_TEST(RocprimDeviceSearchNTests, SmallCount) ASSERT_EQ(h_output, expected); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { gHelper.cleanupGraphHelper(); HIP_CHECK(hipStreamDestroy(stream)); @@ -531,16 +531,16 @@ TYPED_TEST(RocprimDeviceSearchNTests, StartFromBegin) std::vector h_input(size); std::fill(h_input.begin(), h_input.begin() + (size - count), h_value); std::fill(h_input.begin() + count, h_input.end(), 0); - output_type h_output; - common::device_ptr d_input(h_input); - common::device_ptr d_value(std::vector({h_value})); - common::device_ptr d_output(1); + output_type h_output; + common::device_ptr d_input(h_input); + common::device_ptr d_value(std::vector({h_value})); + common::device_ptr d_output(1); SCOPED_TRACE(testing::Message() << "with size = " << h_input.size()); SCOPED_TRACE(testing::Message() << "with count = " << count); SCOPED_TRACE(testing::Message() << "with value = " << h_value); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { // Default stream does not support hipGraph stream capture, so create one HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); @@ -554,8 +554,8 @@ TYPED_TEST(RocprimDeviceSearchNTests, StartFromBegin) count, nullptr)); - common::device_ptr d_temp_storage(temp_storage_size); - test_utils::GraphHelper gHelper; + common::device_ptr d_temp_storage(temp_storage_size); + test_utils::GraphHelper gHelper; if(TestFixture::use_graphs) { gHelper.startStreamCapture(stream); @@ -572,7 +572,7 @@ TYPED_TEST(RocprimDeviceSearchNTests, StartFromBegin) stream, debug_synchronous)); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { gHelper.createAndLaunchGraph(stream); } @@ -588,7 +588,7 @@ TYPED_TEST(RocprimDeviceSearchNTests, StartFromBegin) ASSERT_EQ(h_output, expected); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { gHelper.cleanupGraphHelper(); HIP_CHECK(hipStreamDestroy(stream)); @@ -625,16 +625,16 @@ TYPED_TEST(RocprimDeviceSearchNTests, StartFromMiddle) std::vector h_input(size); std::fill(h_input.begin(), h_input.begin() + (size - count), 0); std::fill(h_input.begin() + count, h_input.end(), h_value); - output_type h_output; - common::device_ptr d_input(h_input); - common::device_ptr d_value(std::vector({h_value})); - common::device_ptr d_output(1); + output_type h_output; + common::device_ptr d_input(h_input); + common::device_ptr d_value(std::vector({h_value})); + common::device_ptr d_output(1); SCOPED_TRACE(testing::Message() << "with size = " << h_input.size()); SCOPED_TRACE(testing::Message() << "with count = " << count); SCOPED_TRACE(testing::Message() << "with value = " << h_value); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { // Default stream does not support hipGraph stream capture, so create one HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); @@ -648,8 +648,8 @@ TYPED_TEST(RocprimDeviceSearchNTests, StartFromMiddle) count, nullptr)); - common::device_ptr d_temp_storage(temp_storage_size); - test_utils::GraphHelper gHelper; + common::device_ptr d_temp_storage(temp_storage_size); + test_utils::GraphHelper gHelper; if(TestFixture::use_graphs) { gHelper.startStreamCapture(stream); @@ -666,7 +666,7 @@ TYPED_TEST(RocprimDeviceSearchNTests, StartFromMiddle) stream, debug_synchronous)); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { gHelper.createAndLaunchGraph(stream); } @@ -682,7 +682,7 @@ TYPED_TEST(RocprimDeviceSearchNTests, StartFromMiddle) ASSERT_EQ(h_output, expected); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { gHelper.cleanupGraphHelper(); HIP_CHECK(hipStreamDestroy(stream)); @@ -719,16 +719,16 @@ TYPED_TEST(RocprimDeviceSearchNTests, StartFromEnd) std::vector h_input(size); std::fill(h_input.begin(), h_input.begin() + (size - count), 0); std::fill(h_input.begin() + (size - count), h_input.end(), h_value); - output_type h_output; - common::device_ptr d_input(h_input); - common::device_ptr d_value(std::vector({h_value})); - common::device_ptr d_output(1); + output_type h_output; + common::device_ptr d_input(h_input); + common::device_ptr d_value(std::vector({h_value})); + common::device_ptr d_output(1); SCOPED_TRACE(testing::Message() << "with size = " << h_input.size()); SCOPED_TRACE(testing::Message() << "with count = " << count); SCOPED_TRACE(testing::Message() << "with value = " << h_value); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { // Default stream does not support hipGraph stream capture, so create one HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); @@ -742,8 +742,8 @@ TYPED_TEST(RocprimDeviceSearchNTests, StartFromEnd) count, nullptr)); - common::device_ptr d_temp_storage(temp_storage_size); - test_utils::GraphHelper gHelper; + common::device_ptr d_temp_storage(temp_storage_size); + test_utils::GraphHelper gHelper; if(TestFixture::use_graphs) { gHelper.startStreamCapture(stream); @@ -760,7 +760,7 @@ TYPED_TEST(RocprimDeviceSearchNTests, StartFromEnd) stream, debug_synchronous)); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { gHelper.createAndLaunchGraph(stream); } @@ -776,7 +776,7 @@ TYPED_TEST(RocprimDeviceSearchNTests, StartFromEnd) ASSERT_EQ(h_output, expected); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { gHelper.cleanupGraphHelper(); HIP_CHECK(hipStreamDestroy(stream)); @@ -817,16 +817,16 @@ TYPED_TEST(RocprimDeviceSearchNTests, StartFromEndButFail) { count += 2; } - output_type h_output; - common::device_ptr d_input(h_input); - common::device_ptr d_value(std::vector({h_value})); - common::device_ptr d_output(1); + output_type h_output; + common::device_ptr d_input(h_input); + common::device_ptr d_value(std::vector({h_value})); + common::device_ptr d_output(1); SCOPED_TRACE(testing::Message() << "with size = " << h_input.size()); SCOPED_TRACE(testing::Message() << "with count = " << count); SCOPED_TRACE(testing::Message() << "with value = " << h_value); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { // Default stream does not support hipGraph stream capture, so create one HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); @@ -840,8 +840,8 @@ TYPED_TEST(RocprimDeviceSearchNTests, StartFromEndButFail) count, nullptr)); - common::device_ptr d_temp_storage(temp_storage_size); - test_utils::GraphHelper gHelper; + common::device_ptr d_temp_storage(temp_storage_size); + test_utils::GraphHelper gHelper; if(TestFixture::use_graphs) { gHelper.startStreamCapture(stream); @@ -858,7 +858,7 @@ TYPED_TEST(RocprimDeviceSearchNTests, StartFromEndButFail) stream, debug_synchronous)); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { gHelper.createAndLaunchGraph(stream); } @@ -874,7 +874,7 @@ TYPED_TEST(RocprimDeviceSearchNTests, StartFromEndButFail) ASSERT_EQ(h_output, expected); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { gHelper.cleanupGraphHelper(); HIP_CHECK(hipStreamDestroy(stream)); @@ -946,7 +946,7 @@ TYPED_TEST(RocprimDeviceSearchNTests, NoiseTest_1block) SCOPED_TRACE(testing::Message() << "with count = " << count); SCOPED_TRACE(testing::Message() << "with value = " << h_value); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { // Default stream does not support hipGraph stream capture, so create one HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); @@ -960,8 +960,8 @@ TYPED_TEST(RocprimDeviceSearchNTests, NoiseTest_1block) count, nullptr)); - common::device_ptr d_temp_storage(temp_storage_size); - test_utils::GraphHelper gHelper; + common::device_ptr d_temp_storage(temp_storage_size); + test_utils::GraphHelper gHelper; if(TestFixture::use_graphs) { gHelper.startStreamCapture(stream); @@ -978,7 +978,7 @@ TYPED_TEST(RocprimDeviceSearchNTests, NoiseTest_1block) stream, debug_synchronous)); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { gHelper.createAndLaunchGraph(stream); } @@ -994,7 +994,7 @@ TYPED_TEST(RocprimDeviceSearchNTests, NoiseTest_1block) ASSERT_EQ(h_output, expected); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { gHelper.cleanupGraphHelper(); HIP_CHECK(hipStreamDestroy(stream)); @@ -1066,7 +1066,7 @@ TYPED_TEST(RocprimDeviceSearchNTests, NoiseTest_2block) SCOPED_TRACE(testing::Message() << "with count = " << count); SCOPED_TRACE(testing::Message() << "with value = " << h_value); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { // Default stream does not support hipGraph stream capture, so create one HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); @@ -1080,8 +1080,8 @@ TYPED_TEST(RocprimDeviceSearchNTests, NoiseTest_2block) count, nullptr)); - common::device_ptr d_temp_storage(temp_storage_size); - test_utils::GraphHelper gHelper; + common::device_ptr d_temp_storage(temp_storage_size); + test_utils::GraphHelper gHelper; if(TestFixture::use_graphs) { gHelper.startStreamCapture(stream); @@ -1098,7 +1098,7 @@ TYPED_TEST(RocprimDeviceSearchNTests, NoiseTest_2block) stream, debug_synchronous)); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { gHelper.createAndLaunchGraph(stream); } @@ -1114,7 +1114,7 @@ TYPED_TEST(RocprimDeviceSearchNTests, NoiseTest_2block) ASSERT_EQ(h_output, expected); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { gHelper.cleanupGraphHelper(); HIP_CHECK(hipStreamDestroy(stream)); @@ -1186,7 +1186,7 @@ TYPED_TEST(RocprimDeviceSearchNTests, NoiseTest_3block) SCOPED_TRACE(testing::Message() << "with count = " << count); SCOPED_TRACE(testing::Message() << "with value = " << h_value); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { // Default stream does not support hipGraph stream capture, so create one HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); @@ -1200,8 +1200,8 @@ TYPED_TEST(RocprimDeviceSearchNTests, NoiseTest_3block) count, nullptr)); - common::device_ptr d_temp_storage(temp_storage_size); - test_utils::GraphHelper gHelper; + common::device_ptr d_temp_storage(temp_storage_size); + test_utils::GraphHelper gHelper; if(TestFixture::use_graphs) { gHelper.startStreamCapture(stream); @@ -1218,7 +1218,7 @@ TYPED_TEST(RocprimDeviceSearchNTests, NoiseTest_3block) stream, debug_synchronous)); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { gHelper.createAndLaunchGraph(stream); } @@ -1234,7 +1234,7 @@ TYPED_TEST(RocprimDeviceSearchNTests, NoiseTest_3block) ASSERT_EQ(h_output, expected); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { gHelper.cleanupGraphHelper(); HIP_CHECK(hipStreamDestroy(stream)); @@ -1308,7 +1308,7 @@ TYPED_TEST(RocprimDeviceSearchNTests, MultiResult1) SCOPED_TRACE(testing::Message() << "with count = " << count); SCOPED_TRACE(testing::Message() << "with value = " << h_value); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { // Default stream does not support hipGraph stream capture, so create one HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); @@ -1322,8 +1322,8 @@ TYPED_TEST(RocprimDeviceSearchNTests, MultiResult1) count, nullptr)); - common::device_ptr d_temp_storage(temp_storage_size); - test_utils::GraphHelper gHelper; + common::device_ptr d_temp_storage(temp_storage_size); + test_utils::GraphHelper gHelper; if(TestFixture::use_graphs) { gHelper.startStreamCapture(stream); @@ -1340,7 +1340,7 @@ TYPED_TEST(RocprimDeviceSearchNTests, MultiResult1) stream, debug_synchronous)); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { gHelper.createAndLaunchGraph(stream); } @@ -1356,7 +1356,7 @@ TYPED_TEST(RocprimDeviceSearchNTests, MultiResult1) ASSERT_EQ(h_output, expected); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { gHelper.cleanupGraphHelper(); HIP_CHECK(hipStreamDestroy(stream)); @@ -1423,7 +1423,7 @@ TYPED_TEST(RocprimDeviceSearchNTests, MultiResult2) SCOPED_TRACE(testing::Message() << "with count = " << count); SCOPED_TRACE(testing::Message() << "with value = " << h_value); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { // Default stream does not support hipGraph stream capture, so create one HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); @@ -1437,8 +1437,8 @@ TYPED_TEST(RocprimDeviceSearchNTests, MultiResult2) count, nullptr)); - common::device_ptr d_temp_storage(temp_storage_size); - test_utils::GraphHelper gHelper; + common::device_ptr d_temp_storage(temp_storage_size); + test_utils::GraphHelper gHelper; if(TestFixture::use_graphs) { gHelper.startStreamCapture(stream); @@ -1455,7 +1455,7 @@ TYPED_TEST(RocprimDeviceSearchNTests, MultiResult2) stream, debug_synchronous)); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { gHelper.createAndLaunchGraph(stream); } @@ -1471,7 +1471,7 @@ TYPED_TEST(RocprimDeviceSearchNTests, MultiResult2) ASSERT_EQ(h_output, expected); - if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + if constexpr(TestFixture::use_graphs) { gHelper.cleanupGraphHelper(); HIP_CHECK(hipStreamDestroy(stream)); diff --git a/test/rocprim/test_device_segmented_radix_sort.hpp b/test/rocprim/test_device_segmented_radix_sort.hpp index d34c1c1f3..9068614cf 100644 --- a/test/rocprim/test_device_segmented_radix_sort.hpp +++ b/test/rocprim/test_device_segmented_radix_sort.hpp @@ -70,14 +70,12 @@ struct params }; using config_default - = rocprim::segmented_radix_sort_config<4, //< long radix bits - 3, //< short radix bits + = rocprim::segmented_radix_sort_config<4, //< radix bits rocprim::kernel_config<256, //< sort block size, 4>>; //< items per thread using config_semi_custom - = rocprim::segmented_radix_sort_config<3, //< long radix bits - 2, //< short radix bits + = rocprim::segmented_radix_sort_config<3, //< radix bits rocprim::kernel_config<128, //< sort block size 4>, //< items per thread rocprim::WarpSortConfig<16, //< logical warp size small @@ -85,8 +83,7 @@ using config_semi_custom false>; //< enable unpartitioned sort using config_semi_custom_warp_config - = rocprim::segmented_radix_sort_config<3, //< long radix bits - 2, //< short radix bits + = rocprim::segmented_radix_sort_config<3, //< radix bits rocprim::kernel_config<128, //< sort block size 4>, //< items per thread rocprim::WarpSortConfig<16, //< logical warp size small @@ -96,8 +93,7 @@ using config_semi_custom_warp_config true>; //< enable unpartitioned sort using config_custom - = rocprim::segmented_radix_sort_config<3, //< long radix bits - 2, //< short radix bits + = rocprim::segmented_radix_sort_config<3, //< radix bits rocprim::kernel_config<128, //< sort block size 4>, //< items per thread rocprim::WarpSortConfig<16, //< logical warp size small diff --git a/test/rocprim/test_device_transform.cpp b/test/rocprim/test_device_transform.cpp index 47c4052de..514b46572 100644 --- a/test/rocprim/test_device_transform.cpp +++ b/test/rocprim/test_device_transform.cpp @@ -78,9 +78,10 @@ class RocprimDeviceTransformTests : public ::testing::Test static constexpr bool use_graphs = Params::use_graphs; }; -using custom_short2 = common::custom_type; -using custom_int2 = common::custom_type; -using custom_double2 = common::custom_type; +using custom_short2 = common::custom_type; +using custom_int2 = common::custom_type; +using custom_double2 = common::custom_type; +using custom_int64_array = test_utils::custom_test_array_type; using RocprimDeviceTransformTestsParams = ::testing::Types, @@ -92,7 +93,10 @@ using RocprimDeviceTransformTestsParams DeviceTransformParams, DeviceTransformParams, DeviceTransformParams, + DeviceTransformParams, + DeviceTransformParams, DeviceTransformParams, + DeviceTransformParams, DeviceTransformParams, DeviceTransformParams, DeviceTransformParams, @@ -304,6 +308,120 @@ TYPED_TEST(RocprimDeviceTransformTests, BinaryTransform) } } +template +OutputIterator transform_nary(Functor f, size_t size, OutputIterator out, Inputs... inputs) +{ + for(size_t i = 0; i < size; i++) + { + *out++ = f(*inputs++...); + } + + return out; +} + +template +struct ternary_transform +{ + __device__ __host__ + inline constexpr U + operator()(const T1& a, const T2& b, const T3& c) const + { + return a + b + c; + } +}; + +TYPED_TEST(RocprimDeviceTransformTests, TernaryTransform) +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); + HIP_CHECK(hipSetDevice(device_id)); + + using T1 = typename TestFixture::input_type; + using T2 = typename TestFixture::input_type; + using U = typename TestFixture::output_type; + static constexpr bool use_identity_iterator = TestFixture::use_identity_iterator; + const bool debug_synchronous = TestFixture::debug_synchronous; + using Config = size_limit_config_t; + + for(size_t seed_index = 0; seed_index < number_of_runs; seed_index++) + { + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); + + for(auto size : test_utils::get_sizes(seed_value)) + { + hipStream_t stream = 0; // default + if(TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + + SCOPED_TRACE(testing::Message() << "with size = " << size); + + // Generate data + std::vector input1 + = test_utils::get_random_data_wrapped(size, 1, 100, seed_value); + std::vector input2 + = test_utils::get_random_data_wrapped(size, 1, 100, seed_value); + std::vector input3 + = test_utils::get_random_data_wrapped(size, 1, 100, seed_value); + + common::device_ptr d_input1(input1); + common::device_ptr d_input2(input2); + common::device_ptr d_input3(input3); + common::device_ptr d_output(input1.size()); + + // Calculate expected results on host + std::vector expected(input1.size()); + + transform_nary(ternary_transform(), + input1.size(), + expected.begin(), + input1.begin(), + input2.begin(), + input3.begin()); + + test_utils::GraphHelper gHelper; + if(TestFixture::use_graphs) + { + gHelper.startStreamCapture(stream); + } + + // Run + HIP_CHECK(rocprim::transform( + rocprim::tuple(d_input1.get(), d_input2.get(), d_input3.get()), + test_utils::wrap_in_identity_iterator(d_output.get()), + input1.size(), + ternary_transform(), + stream, + debug_synchronous)); + + if(TestFixture::use_graphs) + { + gHelper.createAndLaunchGraph(stream, true, false); + } + + HIP_CHECK(hipGetLastError()); + HIP_CHECK(hipDeviceSynchronize()); + + // Copy output to host + const auto output = d_output.load(); + + // Check if output values are as expected + ASSERT_NO_FATAL_FAILURE( + test_utils::assert_near(output, expected, test_utils::precision)); + + if(TestFixture::use_graphs) + { + gHelper.cleanupGraphHelper(); + HIP_CHECK(hipStreamDestroy(stream)); + } + } + } +} + template struct flag_expected_op_t { @@ -419,3 +537,65 @@ TEST(RocprimDeviceTransformTests, LargeIndicesWithGraphs) { testLargeIndices(); } + +TEST(RocprimDeviceTransformTests, UnalignedPointer) +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); + HIP_CHECK(hipSetDevice(device_id)); + + using T = int; + + for(size_t seed_index = 0; seed_index < number_of_runs; seed_index++) + { + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); + + for(auto size : test_utils::get_sizes(seed_value)) + { + hipStream_t stream = 0; // default + + SCOPED_TRACE(testing::Message() << "with size = " << size); + + // Generate data + std::vector input + = test_utils::get_random_data_wrapped(size + 2, 1, 100, seed_value); + + uint8_t* d_unaligned; + HIP_CHECK(hipMalloc(&d_unaligned, (size + 3) * sizeof(T))); + T* d_input = reinterpret_cast(d_unaligned + 1); + HIP_CHECK( + hipMemcpy(d_input, input.data(), (size + 2) * sizeof(T), hipMemcpyHostToDevice)); + + // Calculate expected results on host + std::vector expected(input.size()); + // First and last values should be unchanged. + expected[0] = input[0]; + expected[input.size() - 1] = input[input.size() - 1]; + std::transform(input.begin() + 1, + input.end() - 1, + expected.begin() + 1, + transform()); + + // Run + HIP_CHECK(rocprim::transform(d_input + 1, + d_input + 1, + input.size() - 2, + transform(), + stream)); + + HIP_CHECK(hipGetLastError()); + HIP_CHECK(hipDeviceSynchronize()); + + // Copy output to host + std::vector output(size + 2); + HIP_CHECK( + hipMemcpy(output.data(), d_input, (size + 2) * sizeof(T), hipMemcpyDeviceToHost)); + + // Check if output values are as expected + ASSERT_NO_FATAL_FAILURE( + test_utils::assert_near(output, expected, test_utils::precision)); + } + } +} diff --git a/test/rocprim/test_invoke_result.cpp b/test/rocprim/test_invoke_result.cpp index e6f163cc5..8039cb547 100644 --- a/test/rocprim/test_invoke_result.cpp +++ b/test/rocprim/test_invoke_result.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -74,7 +74,7 @@ TYPED_TEST(RocprimInvokeResultBinOpTests, HostInvokeResult) using binary_function = typename TestFixture::function; using expected_type = typename TestFixture::expected_type; - using resulting_type = rocprim::invoke_result_binary_op_t; + using resulting_type = ::rocprim::accumulator_t; // Compile and check on host static_assert(std::is_same::value, diff --git a/test/rocprim/test_linking_new_scan.hpp b/test/rocprim/test_linking_new_scan.hpp index cb2745103..6572a84cf 100644 --- a/test/rocprim/test_linking_new_scan.hpp +++ b/test/rocprim/test_linking_new_scan.hpp @@ -163,7 +163,9 @@ template::value_type>, - class AccType = typename std::iterator_traits::value_type> + class AccType + = ::rocprim::accumulator_t::value_type>> inline hipError_t inclusive_scan(void* temporary_storage, size_t& storage_size, InputIterator input, diff --git a/test/rocprim/test_radix_key_codec.cpp b/test/rocprim/test_radix_key_codec.cpp index beab09c54..74bf6cf77 100644 --- a/test/rocprim/test_radix_key_codec.cpp +++ b/test/rocprim/test_radix_key_codec.cpp @@ -33,7 +33,7 @@ #include "test_utils_sort_comparator.hpp" #include -#include +#include #include #include @@ -94,7 +94,7 @@ struct custom_key_decomposer TEST_P(RadixKeyCodecTest, ExtractDigit) { - using codec = rocprim::detail::radix_key_codec; + using codec = decltype(rocprim::traits::get().radix_key_codec()); const custom_key key{0xab, 0xcdef, 0x01}; const auto digit = codec::extract_digit(key, @@ -126,7 +126,7 @@ struct custom_key_decomposer_with_unused TEST_P(RadixKeyCodecUnusedTest, ExtractDigitUnused) { - using codec = rocprim::detail::radix_key_codec; + using codec = decltype(rocprim::traits::get().radix_key_codec()); const custom_key key{0xab, 0xcdef, 0x01}; const auto digit = codec::extract_digit(key, @@ -140,7 +140,7 @@ TEST_P(RadixKeyCodecUnusedTest, ExtractDigitUnused) TEST(RadixKeyCodecTest, ExtractCustomTestType) { using T = common::custom_type; - using codec_t = rocprim::detail::radix_key_codec; + using codec_t = decltype(rocprim::traits::get().radix_key_codec()); T value{12, 34}; @@ -328,7 +328,9 @@ TYPED_TEST_SUITE(TypedRadixKeyCodecTest, TypedRadixKeyCodecTestTypes); template void encode_then_decode_test(Key key, Decomposer decomposer) { - using codec_t = ::rocprim::radix_key_codec; + constexpr auto input_traits = ::rocprim::traits::get(); + constexpr auto codec = input_traits.template radix_key_codec(); + using codec_t = decltype(codec); using BitKey = typename codec_t::bit_key_type; BitKey bit_key = codec_t::encode(key, decomposer); @@ -353,7 +355,9 @@ void encode_then_extract_test(Key key, const unsigned int radix_bits, Decomposer decomposer) { - using codec_t = ::rocprim::radix_key_codec; + constexpr auto input_traits = ::rocprim::traits::get(); + constexpr auto codec = input_traits.template radix_key_codec(); + using codec_t = decltype(codec); using BitKey = typename codec_t::bit_key_type; BitKey bit_key = codec_t::encode(key, decomposer); @@ -382,7 +386,9 @@ void encode_then_extract_test_custom(Key key, const unsigned int radix_bits, Decomposer decomposer) { - using codec_t = ::rocprim::radix_key_codec; + constexpr auto input_traits = ::rocprim::traits::get(); + constexpr auto codec = input_traits.template radix_key_codec(); + using codec_t = decltype(codec); using BitKey = typename codec_t::bit_key_type; BitKey bit_key = codec_t::encode(key, decomposer); diff --git a/test/rocprim/test_type_traits_interface.cpp b/test/rocprim/test_type_traits_interface.cpp index cc24960ac..db214c575 100644 --- a/test/rocprim/test_type_traits_interface.cpp +++ b/test/rocprim/test_type_traits_interface.cpp @@ -23,7 +23,7 @@ #include "test_utils_custom_float_type.hpp" #include -#include +#include #include #include @@ -52,9 +52,6 @@ inline std::ostream& operator<<(std::ostream& stream, const custom_float_type& v return test_utils::operator<<(stream, value); } -struct float_bit_masked_type -{}; - // Custom type to model types like Eigen::half or Eigen::bfloat16, that wrap around floating point // types. struct custom_int_type @@ -135,15 +132,6 @@ struct rocprim::traits::define = rocprim::traits::integral_sign::values; }; -template<> -struct rocprim::detail::float_bit_mask -{ - static constexpr uint32_t sign_bit = 0x80000000; - static constexpr uint32_t exponent = 0x7F800000; - static constexpr uint32_t mantissa = 0x007FFFFF; - using bit_type = uint32_t; -}; - template class RocprimFloatingPointTests : public ::testing::Test { @@ -190,7 +178,7 @@ TYPED_TEST(RocprimFloatingPointTests, FloatingPoint) ROCPRIM_STATIC_ASSERT_EQ(input_traits.is_integral(), rocprim::is_integral::value); // cannot do static_assert because under c++ 14 there is no if constexpr - if ROCPRIM_IF_CONSTEXPR(rocprim::is_arithmetic::value) + if constexpr(rocprim::is_arithmetic::value) { // for c++ arithmetic types ASSERT_EQ(input_traits.is_compound(), rocprim::is_compound::value); ASSERT_EQ(input_traits.is_scalar(), rocprim::is_scalar::value); @@ -227,7 +215,7 @@ TYPED_TEST(RocprimIntegralTests, Integral) rocprim::is_floating_point::value); ROCPRIM_STATIC_ASSERT_NE(input_traits.is_signed(), input_traits.is_unsigned()); - if ROCPRIM_IF_CONSTEXPR(rocprim::is_arithmetic::value) + if constexpr(rocprim::is_arithmetic::value) { // for c++ arithmetic types ASSERT_EQ(input_traits.is_compound(), rocprim::is_compound::value); ASSERT_EQ(input_traits.is_scalar(), rocprim::is_scalar::value); @@ -247,29 +235,7 @@ TYPED_TEST(RocprimIntegralTests, Integral) } } -TEST(TraitsInterface, OldType) -{ - using input_type = type_traits_test::float_bit_masked_type; - using bit_mask = rocprim::detail::float_bit_mask; - - constexpr auto input_traits = rocprim::traits::get(); - ROCPRIM_STATIC_ASSERT_FALSE(input_traits.is_arithmetic()); - ROCPRIM_STATIC_ASSERT_FALSE(input_traits.is_fundamental()); - ROCPRIM_STATIC_ASSERT_TRUE(input_traits.is_compound()); - - ROCPRIM_STATIC_ASSERT_TRUE(input_traits.is_compound()); - ROCPRIM_STATIC_ASSERT_FALSE(input_traits.is_scalar()); - ROCPRIM_STATIC_ASSERT_FALSE(input_traits.is_floating_point()); - ROCPRIM_STATIC_ASSERT_FALSE(input_traits.is_integral()); - - constexpr auto float_bit_mask = input_traits.float_bit_mask(); - - ROCPRIM_STATIC_ASSERT_EQ(float_bit_mask.sign_bit, bit_mask::sign_bit); - ROCPRIM_STATIC_ASSERT_EQ(float_bit_mask.exponent, bit_mask::exponent); - ROCPRIM_STATIC_ASSERT_EQ(float_bit_mask.mantissa, bit_mask::mantissa); -} - -TEST(TraitsInterface, OtherType) +TEST(TraitsInterface, AllTypes) { struct TestT {}; diff --git a/test/rocprim/test_utils.hpp b/test/rocprim/test_utils.hpp index e9822c967..ab1a37b71 100644 --- a/test/rocprim/test_utils.hpp +++ b/test/rocprim/test_utils.hpp @@ -46,15 +46,16 @@ #include "test_utils_get_random_data.hpp" #include "test_utils_hipgraphs.hpp" -#include #include #include #include #include +#include #include #include #include +#include #include #include @@ -67,28 +68,28 @@ namespace test_utils // the results of _two_ sequences of operations with different order // For all other operations (i.e. integer arithmetics) default 0 is used template -static constexpr float precision = 0; +inline constexpr float precision = 0; template<> -static constexpr float precision = 2.0f / (1ll << 52); +inline constexpr float precision = 2.0f / (1ll << 52); template<> -static constexpr float precision = 2.0f / (1ll << 23); +inline constexpr float precision = 2.0f / (1ll << 23); template<> -static constexpr float precision = 2.0f / (1ll << 10); +inline constexpr float precision = 2.0f / (1ll << 10); template<> -static constexpr float precision = 2.0f / (1ll << 7); +inline constexpr float precision = 2.0f / (1ll << 7); template -static constexpr float precision = precision; +inline constexpr float precision = precision; template -static constexpr float precision> = precision; +inline constexpr float precision> = precision; template -static constexpr float precision> = precision; +inline constexpr float precision> = precision; template struct is_plus_operator : std::false_type @@ -179,7 +180,7 @@ constexpr std::vector host_reduce(InputIt first, InputIt last, rocprim::plus< return result; } // Calculate expected results on host - accumulator_type expected = accumulator_type(0); + accumulator_type expected = accumulator_type(0); rocprim::plus bin_op; for(int i = size - 1; i >= 0; --i) { @@ -201,7 +202,7 @@ template< = true> constexpr std::vector host_reduce(InputIt first, InputIt last, rocprim::plus op) { - using acc_type = T; + using acc_type = T; size_t size = std::distance(first, last); std::vector result(size); if(size == 0) @@ -219,15 +220,17 @@ constexpr std::vector host_reduce(InputIt first, InputIt last, rocprim::plus< } template -OutputIt host_inclusive_segmented_scan_headflags(InputIt first, InputIt last, FlagsIt flags, - OutputIt d_first, BinaryOperation op) +OutputIt host_inclusive_segmented_scan_headflags( + InputIt first, InputIt last, FlagsIt flags, OutputIt d_first, BinaryOperation op) { - if (first == last) return d_first; + if(first == last) + return d_first; acc_type sum = *first; - *d_first = sum; + *d_first = sum; - while (++first != last) { + while(++first != last) + { ++flags; sum = *flags ? acc_type(*first) : acc_type(op(sum, *first)); *++d_first = sum; @@ -236,15 +239,17 @@ OutputIt host_inclusive_segmented_scan_headflags(InputIt first, InputIt last, Fl } template -OutputIt host_exclusive_segmented_scan_headflags(InputIt first, InputIt last, FlagsIt flags, - OutputIt d_first, BinaryOperation op, acc_type init) +OutputIt host_exclusive_segmented_scan_headflags( + InputIt first, InputIt last, FlagsIt flags, OutputIt d_first, BinaryOperation op, acc_type init) { - if (first == last) return d_first; + if(first == last) + return d_first; acc_type sum = init; - *d_first = sum; + *d_first = sum; - while ((first+1) != last){ + while((first + 1) != last) + { ++flags; sum = *flags ? acc_type(init) : acc_type(op(sum, *first)); *++d_first = sum; @@ -253,55 +258,73 @@ OutputIt host_exclusive_segmented_scan_headflags(InputIt first, InputIt last, Fl return ++d_first; } -template -OutputIt host_inclusive_scan_impl(InputIt first, InputIt last, - OutputIt d_first, BinaryOperation op, acc_type) +template +OutputIt host_inclusive_scan_impl( + InputIt first, InputIt last, OutputIt d_first, BinaryOperation op, acc_type initial_value) { - if (first == last) return d_first; + if(first == last) + return d_first; - acc_type sum = *first; - *d_first = sum; + acc_type sum = UseInitialValue ? op(initial_value, *first) : static_cast(*first); + *d_first = sum; - while (++first != last) { - sum = op(sum, *first); + while(++first != last) + { + sum = op(sum, *first); *++d_first = sum; } return ++d_first; } template -OutputIt host_inclusive_scan(InputIt first, InputIt last, - OutputIt d_first, BinaryOperation op) +OutputIt host_inclusive_scan(InputIt first, InputIt last, OutputIt d_first, BinaryOperation op) { - using acc_type = rocprim::invoke_result_binary_op_t::value_type, BinaryOperation>; + using acc_type = ::rocprim::accumulator_t::value_type>; return host_inclusive_scan_impl(first, last, d_first, op, acc_type{}); } -template::value_type, rocprim::bfloat16>::value || - std::is_same::value_type, rocprim::half>::value || - std::is_same::value_type, float>::value - , bool> = true> -OutputIt host_inclusive_scan(InputIt first, InputIt last, - OutputIt d_first, rocprim::plus) +template +OutputIt host_inclusive_scan( + InputIt first, InputIt last, OutputIt d_first, BinaryOperation op, InitValueType initial_value) +{ + return host_inclusive_scan_impl(first, last, d_first, op, initial_value); +} + +template< + class InputIt, + class OutputIt, + class T, + std::enable_if_t< + std::is_same::value_type, rocprim::bfloat16>::value + || std::is_same::value_type, + rocprim::half>::value + || std::is_same::value_type, float>::value, + bool> + = true> +OutputIt host_inclusive_scan(InputIt first, InputIt last, OutputIt d_first, rocprim::plus) { using acc_type = double; return host_inclusive_scan_impl(first, last, d_first, rocprim::plus(), acc_type{}); } template -OutputIt host_exclusive_scan_impl(InputIt first, InputIt last, - T initial_value, OutputIt d_first, - BinaryOperation op, acc_type) +OutputIt host_exclusive_scan_impl( + InputIt first, InputIt last, T initial_value, OutputIt d_first, BinaryOperation op, acc_type) { - if (first == last) return d_first; + if(first == last) + return d_first; acc_type sum = initial_value; - *d_first = initial_value; + *d_first = initial_value; - while ((first+1) != last) + while((first + 1) != last) { - sum = op(sum, *first); + sum = op(sum, *first); *++d_first = sum; first++; } @@ -309,40 +332,62 @@ OutputIt host_exclusive_scan_impl(InputIt first, InputIt last, } template -OutputIt host_exclusive_scan(InputIt first, InputIt last, - T initial_value, OutputIt d_first, - BinaryOperation op) +OutputIt host_exclusive_scan( + InputIt first, InputIt last, T initial_value, OutputIt d_first, BinaryOperation op) { - using acc_type = rocprim::invoke_result_binary_op_t, BinaryOperation>; + using acc_type = ::rocprim::accumulator_t>; return host_exclusive_scan_impl(first, last, initial_value, d_first, op, acc_type{}); } -template::value_type, rocprim::bfloat16>::value || - std::is_same::value_type, rocprim::half>::value || - std::is_same::value_type, float>::value - , bool> = true> -OutputIt host_exclusive_scan(InputIt first, InputIt last, - T initial_value, OutputIt d_first, - rocprim::plus) +template< + class InputIt, + class T, + class OutputIt, + class U, + std::enable_if_t< + std::is_same::value_type, rocprim::bfloat16>::value + || std::is_same::value_type, + rocprim::half>::value + || std::is_same::value_type, float>::value, + bool> + = true> +OutputIt host_exclusive_scan( + InputIt first, InputIt last, T initial_value, OutputIt d_first, rocprim::plus) { using acc_type = double; - return host_exclusive_scan_impl(first, last, initial_value, d_first, rocprim::plus(), acc_type{}); + return host_exclusive_scan_impl(first, + last, + initial_value, + d_first, + rocprim::plus(), + acc_type{}); } -template -OutputIt host_exclusive_scan_by_key_impl(InputIt first, InputIt last, KeyIt k_first, - T initial_value, OutputIt d_first, - BinaryOperation op, KeyCompare key_compare_op, acc_type) -{ - if (first == last) return d_first; +template +OutputIt host_exclusive_scan_by_key_impl(InputIt first, + InputIt last, + KeyIt k_first, + T initial_value, + OutputIt d_first, + BinaryOperation op, + KeyCompare key_compare_op, + acc_type) +{ + if(first == last) + return d_first; acc_type sum = initial_value; - *d_first = initial_value; + *d_first = initial_value; - while ((first+1) != last) + while((first + 1) != last) { - if(key_compare_op(*k_first, *(k_first+1))) + if(key_compare_op(*k_first, *(k_first + 1))) { sum = op(sum, *first); } @@ -356,41 +401,87 @@ OutputIt host_exclusive_scan_by_key_impl(InputIt first, InputIt last, KeyIt k_fi } return ++d_first; } -template -OutputIt host_exclusive_scan_by_key(InputIt first, InputIt last, KeyIt k_first, - T initial_value, OutputIt d_first, - BinaryOperation op, KeyCompare key_compare_op) +template +OutputIt host_exclusive_scan_by_key(InputIt first, + InputIt last, + KeyIt k_first, + T initial_value, + OutputIt d_first, + BinaryOperation op, + KeyCompare key_compare_op) { using acc_type = typename std::iterator_traits::value_type; - return host_exclusive_scan_by_key_impl(first, last, k_first, initial_value, d_first, op, key_compare_op, acc_type{}); + return host_exclusive_scan_by_key_impl(first, + last, + k_first, + initial_value, + d_first, + op, + key_compare_op, + acc_type{}); } -template::value_type, rocprim::bfloat16>::value || - std::is_same::value_type, rocprim::half>::value || - std::is_same::value_type, float>::value - , bool> = true> -OutputIt host_exclusive_scan_by_key(InputIt first, InputIt last, KeyIt k_first, - T initial_value, OutputIt d_first, - rocprim::plus, KeyCompare key_compare_op) +template< + class InputIt, + class KeyIt, + class T, + class OutputIt, + class U, + class KeyCompare, + std::enable_if_t< + std::is_same::value_type, rocprim::bfloat16>::value + || std::is_same::value_type, + rocprim::half>::value + || std::is_same::value_type, float>::value, + bool> + = true> +OutputIt host_exclusive_scan_by_key(InputIt first, + InputIt last, + KeyIt k_first, + T initial_value, + OutputIt d_first, + rocprim::plus, + KeyCompare key_compare_op) { using acc_type = double; - return host_exclusive_scan_by_key_impl(first, last, k_first, initial_value, d_first, rocprim::plus(), key_compare_op, acc_type{}); + return host_exclusive_scan_by_key_impl(first, + last, + k_first, + initial_value, + d_first, + rocprim::plus(), + key_compare_op, + acc_type{}); } -template -OutputIt host_inclusive_scan_by_key_impl(InputIt first, InputIt last, KeyIt k_first, - OutputIt d_first, - BinaryOperation op, KeyCompare key_compare_op, acc_type) -{ - if (first == last) return d_first; +template +OutputIt host_inclusive_scan_by_key_impl(InputIt first, + InputIt last, + KeyIt k_first, + OutputIt d_first, + BinaryOperation op, + KeyCompare key_compare_op, + acc_type) +{ + if(first == last) + return d_first; acc_type sum = *first; - *d_first = sum; + *d_first = sum; - while (++first != last) + while(++first != last) { - if(key_compare_op(*k_first, *(k_first+1))) + if(key_compare_op(*k_first, *(k_first + 1))) { sum = op(sum, *first); } @@ -404,29 +495,54 @@ OutputIt host_inclusive_scan_by_key_impl(InputIt first, InputIt last, KeyIt k_fi return ++d_first; } template -OutputIt host_inclusive_scan_by_key(InputIt first, InputIt last, KeyIt k_first, - OutputIt d_first, - BinaryOperation op, KeyCompare key_compare_op) +OutputIt host_inclusive_scan_by_key(InputIt first, + InputIt last, + KeyIt k_first, + OutputIt d_first, + BinaryOperation op, + KeyCompare key_compare_op) { using acc_type = typename std::iterator_traits::value_type; - return host_inclusive_scan_by_key_impl(first, last, k_first, d_first, op, key_compare_op, acc_type{}); + return host_inclusive_scan_by_key_impl(first, + last, + k_first, + d_first, + op, + key_compare_op, + acc_type{}); } -template::value_type, rocprim::bfloat16>::value || - std::is_same::value_type, rocprim::half>::value || - std::is_same::value_type, float>::value - , bool> = true> -OutputIt host_inclusive_scan_by_key(InputIt first, InputIt last, KeyIt k_first, +template< + class InputIt, + class KeyIt, + class OutputIt, + class U, + class KeyCompare, + std::enable_if_t< + std::is_same::value_type, rocprim::bfloat16>::value + || std::is_same::value_type, + rocprim::half>::value + || std::is_same::value_type, float>::value, + bool> + = true> +OutputIt host_inclusive_scan_by_key(InputIt first, + InputIt last, + KeyIt k_first, OutputIt d_first, - rocprim::plus, KeyCompare key_compare_op) + rocprim::plus, + KeyCompare key_compare_op) { using acc_type = double; - return host_inclusive_scan_by_key_impl(first, last, k_first, d_first, rocprim::plus(), key_compare_op, acc_type{}); + return host_inclusive_scan_by_key_impl(first, + last, + k_first, + d_first, + rocprim::plus(), + key_compare_op, + acc_type{}); } -inline -size_t get_max_block_size() +inline size_t get_max_block_size() { int max_threads_blocks{}; @@ -439,7 +555,8 @@ template void iota(ForwardIt first, ForwardIt last, T value) { using value_type = typename std::iterator_traits::value_type; - while(first != last) { + while(first != last) + { *first++ = static_cast(value); ++value; } @@ -513,6 +630,6 @@ inline auto wrap_in_const(T* ptr) -> typename std::enable_if_t return ptr; } -} // end test_utils namespace +} // namespace test_utils #endif // TEST_TEST_UTILS_HPP_ diff --git a/test/rocprim/test_utils_assertions.hpp b/test/rocprim/test_utils_assertions.hpp index 916022a4e..2872eb094 100644 --- a/test/rocprim/test_utils_assertions.hpp +++ b/test/rocprim/test_utils_assertions.hpp @@ -29,7 +29,7 @@ #include "test_utils_bfloat16.hpp" #include "test_utils_custom_test_types.hpp" -#include +#include #include #include diff --git a/test/rocprim/test_utils_custom_float_traits_type.hpp b/test/rocprim/test_utils_custom_float_traits_type.hpp deleted file mode 100644 index 5cd3d9e0f..000000000 --- a/test/rocprim/test_utils_custom_float_traits_type.hpp +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright (c) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. - -#ifndef ROCPRIM_TEST_UTILS_CUSTOM_FLOAT_TRAITS_TYPE_HPP_ -#define ROCPRIM_TEST_UTILS_CUSTOM_FLOAT_TRAITS_TYPE_HPP_ - -#include "../../common/utils_custom_type.hpp" -#include "test_utils_custom_test_types.hpp" - -// For radix_key_codec -#include - -#include -#include - -#include - -namespace test_utils -{ -// Custom type to model types like Eigen::half or Eigen::bfloat16, that wrap around floating point -// types. -struct custom_float_traits_type -{ - using value_type = float; - float x; - - // Constructor for the data generation utilities, simply ignore the second number - ROCPRIM_HOST_DEVICE custom_float_traits_type(float val, float /*ignored*/) : x{val} - {} - - ROCPRIM_HOST_DEVICE custom_float_traits_type(float val) : x{val} {} - - custom_float_traits_type() = default; - - ROCPRIM_HOST_DEVICE - custom_float_traits_type - operator+(const custom_float_traits_type& other) const - { - return custom_float_traits_type(x + other.x); - } - - ROCPRIM_HOST_DEVICE - custom_float_traits_type - operator-(const custom_float_traits_type& other) const - { - return custom_float_traits_type(x - other.x); - } - - ROCPRIM_HOST_DEVICE - bool operator<(const custom_float_traits_type& other) const - { - return x < other.x; - } - - ROCPRIM_HOST_DEVICE - bool operator>(const custom_float_traits_type& other) const - { - return x > other.x; - } - - ROCPRIM_HOST_DEVICE - bool operator==(const custom_float_traits_type& other) const - { - return x == other.x; - } - - ROCPRIM_HOST_DEVICE - bool operator!=(const custom_float_traits_type& other) const - { - return !(*this == other); - } -}; - -inline bool signbit(const custom_float_traits_type& val) -{ - return std::signbit(val.x); -} - -inline std::ostream& operator<<(std::ostream& stream, const custom_float_traits_type& value) -{ - stream << "[" << value.x << "]"; - return stream; -} - -template<> -struct inner_type -{ - using type = custom_float_traits_type::value_type; -}; - -} // namespace test_utils - -namespace common -{ -template<> -struct is_custom_type : std::true_type -{}; -} // namespace common - -template<> -struct ::rocprim::traits::define -{ - using is_arithmetic = ::rocprim::traits::is_arithmetic::values; - using number_format = ::rocprim::traits::number_format::values< - ::rocprim::traits::number_format::kind::floating_point_type>; - using float_bit_mask - = ::rocprim::traits::float_bit_mask::values; -}; - -template<> -struct ::rocprim::detail::radix_key_codec_base - : ::rocprim::detail::radix_key_codec_floating -{}; - -#endif //ROCPRIM_TEST_UTILS_CUSTOM_FLOAT_TYPE_HPP_ diff --git a/test/rocprim/test_utils_custom_float_type.hpp b/test/rocprim/test_utils_custom_float_type.hpp index d186d3310..b14fd32f3 100644 --- a/test/rocprim/test_utils_custom_float_type.hpp +++ b/test/rocprim/test_utils_custom_float_type.hpp @@ -23,13 +23,12 @@ #include "test_utils_custom_test_types.hpp" -// For radix_key_codec -#include - -#include -#include +#include +#include +#include #include +#include namespace test_utils { @@ -110,31 +109,14 @@ struct is_custom_type : std::true_type // because this is something that is unavoidable in some cases we should provide clear customization // points instead of hacks like these. // Nonetheless until that adding a test for this pattern should reduce accidental breakages -namespace rocprim -{ - -namespace detail -{ - template<> -struct float_bit_mask +struct rocprim::traits::define { - static constexpr uint32_t sign_bit = 0x80000000; - static constexpr uint32_t exponent = 0x7F800000; - static constexpr uint32_t mantissa = 0x007FFFFF; - using bit_type = uint32_t; + using is_arithmetic = rocprim::traits::is_arithmetic::values; + using number_format + = rocprim::traits::number_format::values; + using float_bit_mask + = rocprim::traits::float_bit_mask::values; }; -template<> -struct radix_key_codec_base - : radix_key_codec_floating -{}; - -static_assert(!is_floating_point::value, - "custom_float_type must not be rocprim::is_floating_point, " - "since that is how downstream libraries use it."); - -} // namespace detail -} // namespace rocprim - #endif //ROCPRIM_TEST_UTILS_CUSTOM_FLOAT_TYPE_HPP_ diff --git a/test/rocprim/test_utils_data_generation.hpp b/test/rocprim/test_utils_data_generation.hpp index 895e59fc8..b6356a040 100644 --- a/test/rocprim/test_utils_data_generation.hpp +++ b/test/rocprim/test_utils_data_generation.hpp @@ -28,13 +28,11 @@ #include "../common_test_header.hpp" #include "test_seed.hpp" -#include "test_utils_custom_float_traits_type.hpp" #include "test_utils_custom_float_type.hpp" #include "test_utils_custom_test_types.hpp" #include #include -#include #include #include @@ -60,11 +58,6 @@ struct numeric_limits_custom_test_type : public numeric_limits and custom_test_array_type classes -template<> -struct numeric_limits - : detail::numeric_limits_custom_test_type -{}; - template<> struct numeric_limits : detail::numeric_limits_custom_test_type diff --git a/test/rocprim/test_utils_data_generation_with_rocrand.hpp b/test/rocprim/test_utils_data_generation_with_rocrand.hpp index 3c47ba3a2..b8d0698c1 100644 --- a/test/rocprim/test_utils_data_generation_with_rocrand.hpp +++ b/test/rocprim/test_utils_data_generation_with_rocrand.hpp @@ -41,35 +41,32 @@ namespace test_utils_with_rocrand template inline __device__ -auto generate_casting(T* output, StateT& state, U min, V max) +auto generate_casting(T* output, StateT& state, U min, V max, unsigned int global_id) -> std::enable_if_t::value> { - const unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x; - - output[tid] - = static_cast((static_cast(rocrand(&state)) / static_cast((UINT_MAX))) + output[global_id] + = static_cast((static_cast(rocrand(&state)) / static_cast(UINT_MAX)) * (static_cast(max) - static_cast(min)) + static_cast(min)); } template inline __device__ -auto generate_casting(T* output, StateT& state, U min, V max) +auto generate_casting(T* output, StateT& state, U min, V max, unsigned int global_id) -> std::enable_if_t::value> { - const unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x; - float f_value = (static_cast(rocrand(&state)) / static_cast((UINT_MAX))) + float f_value = (static_cast(rocrand(&state)) / static_cast(UINT_MAX)) * (static_cast(max) - static_cast(min)) + static_cast(min); - if ROCPRIM_IF_CONSTEXPR(std::is_same::value) + if constexpr(std::is_same::value) { - output[tid] = static_cast(__float2half_rn(f_value)); + output[global_id] = static_cast(__float2half_rn(f_value)); } else { - output[tid] = f_value; + output[global_id] = f_value; } } @@ -78,48 +75,45 @@ __global__ void generate_random_kernel( T* output, U min, V max, const unsigned long long seed = 0, const unsigned long long offset = 0) { - const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); + const unsigned int flat_id = blockIdx.x * blockDim.x + threadIdx.x; StateT state; const unsigned int subsequence = flat_id; rocrand_init(seed, subsequence, offset, &state); - - generate_casting(output, state, min, max); + generate_casting(output, state, min, max, flat_id); } template inline auto - generate_random_data_n(OutputIter it, size_t size, U min, V max, unsigned long long seed_value) + generate_random_data_n(OutputIter& it, size_t size, U min, V max, unsigned long long seed_value) { - if(size == 0) - return it; + return it.begin() + size; - using T = typename std::iterator_traits::value_type; + using T = typename OutputIter::value_type; // Allocate device memory common::device_ptr d_random_data(size); using state_t = rocrand_state_xorwow; - constexpr int threadsPerBlock = 1024; + constexpr int threadsPerBlock = 512; int blocksPerGrid = (size + threadsPerBlock - 1) / threadsPerBlock; generate_random_kernel <<>>(d_random_data.get(), min, max, seed_value, 0); HIP_CHECK(hipGetLastError()); - // Copy generated data from device to host memory - HIP_CHECK(hipMemcpy(&(*it), d_random_data.get(), size * sizeof(T), hipMemcpyDeviceToHost)); + HIP_CHECK(hipMemcpy(it.data(), d_random_data.get(), size * sizeof(T), hipMemcpyDeviceToHost)); - return it + size; + return it.begin() + size; } template std::vector get_random_data(size_t size, U min, V max, unsigned long long seed_value) { std::vector data(size); - generate_random_data_n(data.begin(), size, min, max, seed_value); + generate_random_data_n(data, size, min, max, seed_value); return data; } diff --git a/test/rocprim/test_utils_sort_comparator.hpp b/test/rocprim/test_utils_sort_comparator.hpp index 3ba1ac195..376e8b848 100644 --- a/test/rocprim/test_utils_sort_comparator.hpp +++ b/test/rocprim/test_utils_sort_comparator.hpp @@ -25,14 +25,12 @@ #include "../../common/utils_custom_type.hpp" -#include "test_utils_custom_float_traits_type.hpp" #include "test_utils_custom_float_type.hpp" #include "test_utils_custom_test_types.hpp" #include #include #include -#include #include #include @@ -99,8 +97,7 @@ template // that we must counter here. - && !std::is_same::value - && !std::is_same::value, + && !std::is_same::value, int> = 0> ROCPRIM_HOST_DEVICE @@ -139,8 +136,7 @@ template // that we must counter here. - && !std::is_same::value - && !std::is_same::value, + && !std::is_same::value, int> = 0> ROCPRIM_HOST_DEVICE @@ -177,17 +173,6 @@ auto to_bits(const Key key) -> typename rocprim::get_unsigned_bits_type::un { return to_bits(key.x); } - -template::value, int> = 0> -ROCPRIM_HOST_DEVICE -auto to_bits(const Key key) -> typename rocprim::get_unsigned_bits_type::unsigned_type -{ - return to_bits(key.x); -} - } // namespace detail template diff --git a/test/rocprim/test_warp_exchange.cpp b/test/rocprim/test_warp_exchange.cpp index d0d2567c3..842606f36 100644 --- a/test/rocprim/test_warp_exchange.cpp +++ b/test/rocprim/test_warp_exchange.cpp @@ -42,12 +42,15 @@ #include #include -template +template struct Params { using type = T; static constexpr unsigned int items_per_thread = ItemsPerThread; - static constexpr unsigned int warp_size = WarpSize; + static constexpr unsigned int warp_size = VirtualWaveSize; using exchange_op = ExchangeOp; }; @@ -137,7 +140,7 @@ auto warp_exchange_test(T* d_input, T* d_output) -> std::enable_if_t> { using warp_exchange_type = ::rocprim::warp_exchange; - constexpr unsigned int num_warps = ::rocprim::arch::wavefront::min_size() / LogicalWarpSize; + constexpr unsigned int num_warps = ::rocprim::arch::wavefront::max_size() / LogicalWarpSize; ROCPRIM_SHARED_MEMORY typename warp_exchange_type::storage_type storage[num_warps]; T thread_data[ItemsPerThread]; @@ -167,7 +170,7 @@ auto warp_exchange_test_not_inplace(T* d_input, T* d_output) -> std::enable_if_t> { using warp_exchange_type = ::rocprim::warp_exchange; - constexpr unsigned int num_warps = ::rocprim::arch::wavefront::min_size() / LogicalWarpSize; + constexpr unsigned int num_warps = ::rocprim::arch::wavefront::max_size() / LogicalWarpSize; ROCPRIM_SHARED_MEMORY typename warp_exchange_type::storage_type storage[num_warps]; T thread_data[ItemsPerThread]; @@ -337,7 +340,7 @@ auto warp_exchange_scatter_test(T* d_input, T* d_output, OffsetT* d_ranks) { using warp_exchange_type = ::rocprim::warp_exchange; - constexpr unsigned int num_warps = ::rocprim::arch::wavefront::min_size() / LogicalWarpSize; + constexpr unsigned int num_warps = ::rocprim::arch::wavefront::max_size() / LogicalWarpSize; ROCPRIM_SHARED_MEMORY typename warp_exchange_type::storage_type storage[num_warps]; T thread_data[ItemsPerThread]; diff --git a/test/rocprim/test_warp_load.cpp b/test/rocprim/test_warp_load.cpp index 2dd6d9935..72f005653 100644 --- a/test/rocprim/test_warp_load.cpp +++ b/test/rocprim/test_warp_load.cpp @@ -41,17 +41,15 @@ #include #include -template< - class T, - unsigned int ItemsPerThread, - unsigned int WarpSize, - ::rocprim::warp_load_method Method -> +template struct Params { using type = T; static constexpr unsigned int items_per_thread = ItemsPerThread; - static constexpr unsigned int warp_size = WarpSize; + static constexpr unsigned int warp_size = VirtualWaveSize; static constexpr ::rocprim::warp_load_method method = Method; }; diff --git a/test/rocprim/test_warp_reduce.hpp b/test/rocprim/test_warp_reduce.hpp index bb1746442..3dfe26a61 100644 --- a/test/rocprim/test_warp_reduce.hpp +++ b/test/rocprim/test_warp_reduce.hpp @@ -39,11 +39,11 @@ test_suite_type_def(suite_name, name_suffix) -typed_test_suite_def(RocprimWarpReduceTests, name_suffix, warp_params); + typed_test_suite_def(RocprimWarpReduceTests, name_suffix, warp_params); typed_test_def(RocprimWarpReduceTests, name_suffix, ReduceSum) { - // logical warp side for warp primitive, execution warp size is always rocprim::warp_size() + // logical warp size for warp primitive, execution warp size is always rocprim::warp_size() using T = typename TestFixture::params::type; // for bfloat16 and half we use double for host-side accumulation using binary_op_type_host = typename test_utils::select_plus_operator_host::type; @@ -61,36 +61,41 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, ReduceSum) static constexpr size_t ws64 = size_t(ROCPRIM_WARP_SIZE_64); // Block size of warp size 32 - static constexpr size_t block_size_ws32 = - rocprim::detail::is_power_of_two(logical_warp_size) - ? rocprim::max(ws32, logical_warp_size * 4) - : rocprim::max((ws32/logical_warp_size), 1) * logical_warp_size; + static constexpr size_t block_size_ws32 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws32, logical_warp_size * 4) + : rocprim::max((ws32 / logical_warp_size), 1) * logical_warp_size; // Block size of warp size 64 - static constexpr size_t block_size_ws64 = - rocprim::detail::is_power_of_two(logical_warp_size) - ? rocprim::max(ws64, logical_warp_size * 4) - : rocprim::max((ws64/logical_warp_size), 1) * logical_warp_size; + static constexpr size_t block_size_ws64 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws64, logical_warp_size * 4) + : rocprim::max((ws64 / logical_warp_size), 1) * logical_warp_size; unsigned int current_device_warp_size; HIP_CHECK(::rocprim::host_warp_size(device_id, current_device_warp_size)); const size_t block_size = current_device_warp_size == ws32 ? block_size_ws32 : block_size_ws64; static constexpr unsigned int grid_size = 4; - const size_t size = block_size * grid_size; + const size_t size = block_size * grid_size; // Check if warp size is supported - if( (logical_warp_size > current_device_warp_size) || - (current_device_warp_size != ws32 && current_device_warp_size != ws64) ) // Only WarpSize 32 and 64 is supported + if((logical_warp_size > current_device_warp_size) + || (current_device_warp_size != ws32 + && current_device_warp_size != ws64)) // Only VirtualWaveSize 32 and 64 is supported { - printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: %u. Skipping test\n", - logical_warp_size, block_size, current_device_warp_size); + printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: " + "%u. Skipping test\n", + logical_warp_size, + block_size, + current_device_warp_size); GTEST_SKIP(); } for(size_t seed_index = 0; seed_index < number_of_runs; seed_index++) { - unsigned int seed_value = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); // Generate data @@ -105,7 +110,7 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, ReduceSum) for(size_t j = 0; j < logical_warp_size; j++) { auto idx = i * logical_warp_size + j; - value = binary_op_host(input[idx], value); + value = binary_op_host(input[idx], value); } expected[i] = static_cast(value); } @@ -114,7 +119,7 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, ReduceSum) common::device_ptr device_output(output.size()); // Launching kernel - if (current_device_warp_size == ws32) + if(current_device_warp_size == ws32) { hipLaunchKernelGGL( HIP_KERNEL_NAME(warp_reduce_sum_kernel), @@ -125,7 +130,7 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, ReduceSum) device_input.get(), device_output.get()); } - else if (current_device_warp_size == ws64) + else if(current_device_warp_size == ws64) { hipLaunchKernelGGL( HIP_KERNEL_NAME(warp_reduce_sum_kernel), @@ -153,9 +158,9 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, AllReduceSum) SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); HIP_CHECK(hipSetDevice(device_id)); - // logical warp side for warp primitive, execution warp size is always rocprim::warp_size() + // logical warp size for warp primitive, execution warp size is always rocprim::warp_size() using T = typename TestFixture::params::type; - // for bfloat16 and half we use double for host-side accumulation + // for bfloat16 and half we use double for host-side accumulation using binary_op_type_host = typename test_utils::select_plus_operator_host::type; binary_op_type_host binary_op_host; using acc_type = typename test_utils::select_plus_operator_host::acc_type; @@ -167,36 +172,41 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, AllReduceSum) static constexpr size_t ws64 = size_t(ROCPRIM_WARP_SIZE_64); // Block size of warp size 32 - static constexpr size_t block_size_ws32 = - rocprim::detail::is_power_of_two(logical_warp_size) - ? rocprim::max(ws32, logical_warp_size * 4) - : rocprim::max((ws32/logical_warp_size), 1) * logical_warp_size; + static constexpr size_t block_size_ws32 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws32, logical_warp_size * 4) + : rocprim::max((ws32 / logical_warp_size), 1) * logical_warp_size; // Block size of warp size 64 - static constexpr size_t block_size_ws64 = - rocprim::detail::is_power_of_two(logical_warp_size) - ? rocprim::max(ws64, logical_warp_size * 4) - : rocprim::max((ws64/logical_warp_size), 1) * logical_warp_size; + static constexpr size_t block_size_ws64 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws64, logical_warp_size * 4) + : rocprim::max((ws64 / logical_warp_size), 1) * logical_warp_size; unsigned int current_device_warp_size; HIP_CHECK(::rocprim::host_warp_size(device_id, current_device_warp_size)); const size_t block_size = current_device_warp_size == ws32 ? block_size_ws32 : block_size_ws64; static constexpr unsigned int grid_size = 4; - const size_t size = block_size * grid_size; + const size_t size = block_size * grid_size; // Check if warp size is supported - if( (logical_warp_size > current_device_warp_size) || - (current_device_warp_size != ws32 && current_device_warp_size != ws64) ) // Only WarpSize 32 and 64 is supported + if((logical_warp_size > current_device_warp_size) + || (current_device_warp_size != ws32 + && current_device_warp_size != ws64)) // Only VirtualWaveSize 32 and 64 is supported { - printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: %u. Skipping test\n", - logical_warp_size, block_size, current_device_warp_size); + printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: " + "%u. Skipping test\n", + logical_warp_size, + block_size, + current_device_warp_size); GTEST_SKIP(); } for(size_t seed_index = 0; seed_index < number_of_runs; seed_index++) { - unsigned int seed_value = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); // Generate data @@ -211,11 +221,11 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, AllReduceSum) for(size_t j = 0; j < logical_warp_size; j++) { auto idx = i * logical_warp_size + j; - value = binary_op_host(input[idx], value); + value = binary_op_host(input[idx], value); } - for (size_t j = 0; j < logical_warp_size; j++) + for(size_t j = 0; j < logical_warp_size; j++) { - auto idx = i * logical_warp_size + j; + auto idx = i * logical_warp_size + j; expected[idx] = static_cast(value); } } @@ -224,7 +234,7 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, AllReduceSum) common::device_ptr device_output(output.size()); // Launching kernel - if (current_device_warp_size == ws32) + if(current_device_warp_size == ws32) { hipLaunchKernelGGL( HIP_KERNEL_NAME(warp_allreduce_sum_kernel), @@ -235,7 +245,7 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, AllReduceSum) device_input.get(), device_output.get()); } - else if (current_device_warp_size == ws64) + else if(current_device_warp_size == ws64) { hipLaunchKernelGGL( HIP_KERNEL_NAME(warp_allreduce_sum_kernel), @@ -263,9 +273,9 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, ReduceSumValid) SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); HIP_CHECK(hipSetDevice(device_id)); - // logical warp side for warp primitive, execution warp size is always rocprim::warp_size() + // logical warp size for warp primitive, execution warp size is always rocprim::warp_size() using T = typename TestFixture::params::type; - // for bfloat16 and half we use double for host-side accumulation + // for bfloat16 and half we use double for host-side accumulation using binary_op_type_host = typename test_utils::select_plus_operator_host::type; binary_op_type_host binary_op_host; using acc_type = typename test_utils::select_plus_operator_host::acc_type; @@ -277,37 +287,42 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, ReduceSumValid) static constexpr size_t ws64 = size_t(ROCPRIM_WARP_SIZE_64); // Block size of warp size 32 - static constexpr size_t block_size_ws32 = - rocprim::detail::is_power_of_two(logical_warp_size) - ? rocprim::max(ws32, logical_warp_size * 4) - : rocprim::max((ws32/logical_warp_size), 1) * logical_warp_size; + static constexpr size_t block_size_ws32 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws32, logical_warp_size * 4) + : rocprim::max((ws32 / logical_warp_size), 1) * logical_warp_size; // Block size of warp size 64 - static constexpr size_t block_size_ws64 = - rocprim::detail::is_power_of_two(logical_warp_size) - ? rocprim::max(ws64, logical_warp_size * 4) - : rocprim::max((ws64/logical_warp_size), 1) * logical_warp_size; + static constexpr size_t block_size_ws64 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws64, logical_warp_size * 4) + : rocprim::max((ws64 / logical_warp_size), 1) * logical_warp_size; unsigned int current_device_warp_size; HIP_CHECK(::rocprim::host_warp_size(device_id, current_device_warp_size)); const size_t block_size = current_device_warp_size == ws32 ? block_size_ws32 : block_size_ws64; static constexpr unsigned int grid_size = 4; - const size_t size = block_size * grid_size; - const size_t valid = logical_warp_size - 1; + const size_t size = block_size * grid_size; + const size_t valid = logical_warp_size - 1; // Check if warp size is supported - if( (logical_warp_size > current_device_warp_size) || - (current_device_warp_size != ws32 && current_device_warp_size != ws64) ) // Only WarpSize 32 and 64 is supported + if((logical_warp_size > current_device_warp_size) + || (current_device_warp_size != ws32 + && current_device_warp_size != ws64)) // Only VirtualWaveSize 32 and 64 is supported { - printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: %u. Skipping test\n", - logical_warp_size, block_size, current_device_warp_size); + printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: " + "%u. Skipping test\n", + logical_warp_size, + block_size, + current_device_warp_size); GTEST_SKIP(); } for(size_t seed_index = 0; seed_index < number_of_runs; seed_index++) { - unsigned int seed_value = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); // Generate data @@ -322,7 +337,7 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, ReduceSumValid) for(size_t j = 0; j < valid; j++) { auto idx = i * logical_warp_size + j; - value = binary_op_host(input[idx], value); + value = binary_op_host(input[idx], value); } expected[i] = static_cast(value); } @@ -331,7 +346,7 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, ReduceSumValid) common::device_ptr device_output(output.size()); // Launching kernel - if (current_device_warp_size == ws32) + if(current_device_warp_size == ws32) { hipLaunchKernelGGL( HIP_KERNEL_NAME(warp_reduce_sum_kernel), @@ -343,7 +358,7 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, ReduceSumValid) device_output.get(), valid); } - else if (current_device_warp_size == ws64) + else if(current_device_warp_size == ws64) { hipLaunchKernelGGL( HIP_KERNEL_NAME(warp_reduce_sum_kernel), @@ -364,7 +379,6 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, ReduceSumValid) test_utils::assert_near(output, expected, test_utils::precision * logical_warp_size); } - } typed_test_def(RocprimWarpReduceTests, name_suffix, AllReduceSumValid) @@ -373,9 +387,9 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, AllReduceSumValid) SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); HIP_CHECK(hipSetDevice(device_id)); - // logical warp side for warp primitive, execution warp size is always rocprim::warp_size() + // logical warp size for warp primitive, execution warp size is always rocprim::warp_size() using T = typename TestFixture::params::type; - // for bfloat16 and half we use double for host-side accumulation + // for bfloat16 and half we use double for host-side accumulation using binary_op_type_host = typename test_utils::select_plus_operator_host::type; binary_op_type_host binary_op_host; using acc_type = typename test_utils::select_plus_operator_host::acc_type; @@ -387,37 +401,42 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, AllReduceSumValid) static constexpr size_t ws64 = size_t(ROCPRIM_WARP_SIZE_64); // Block size of warp size 32 - static constexpr size_t block_size_ws32 = - rocprim::detail::is_power_of_two(logical_warp_size) - ? rocprim::max(ws32, logical_warp_size * 4) - : rocprim::max((ws32/logical_warp_size), 1) * logical_warp_size; + static constexpr size_t block_size_ws32 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws32, logical_warp_size * 4) + : rocprim::max((ws32 / logical_warp_size), 1) * logical_warp_size; // Block size of warp size 64 - static constexpr size_t block_size_ws64 = - rocprim::detail::is_power_of_two(logical_warp_size) - ? rocprim::max(ws64, logical_warp_size * 4) - : rocprim::max((ws64/logical_warp_size), 1) * logical_warp_size; + static constexpr size_t block_size_ws64 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws64, logical_warp_size * 4) + : rocprim::max((ws64 / logical_warp_size), 1) * logical_warp_size; unsigned int current_device_warp_size; HIP_CHECK(::rocprim::host_warp_size(device_id, current_device_warp_size)); const size_t block_size = current_device_warp_size == ws32 ? block_size_ws32 : block_size_ws64; static constexpr unsigned int grid_size = 4; - const size_t size = block_size * grid_size; - const size_t valid = logical_warp_size - 1; + const size_t size = block_size * grid_size; + const size_t valid = logical_warp_size - 1; // Check if warp size is supported - if( (logical_warp_size > current_device_warp_size) || - (current_device_warp_size != ws32 && current_device_warp_size != ws64) ) // Only WarpSize 32 and 64 is supported + if((logical_warp_size > current_device_warp_size) + || (current_device_warp_size != ws32 + && current_device_warp_size != ws64)) // Only VirtualWaveSize 32 and 64 is supported { - printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: %u. Skipping test\n", - logical_warp_size, block_size, current_device_warp_size); + printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: " + "%u. Skipping test\n", + logical_warp_size, + block_size, + current_device_warp_size); GTEST_SKIP(); } for(size_t seed_index = 0; seed_index < number_of_runs; seed_index++) { - unsigned int seed_value = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); // Generate data @@ -432,11 +451,11 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, AllReduceSumValid) for(size_t j = 0; j < valid; j++) { auto idx = i * logical_warp_size + j; - value = binary_op_host(input[idx], value); + value = binary_op_host(input[idx], value); } - for (size_t j = 0; j < logical_warp_size; j++) + for(size_t j = 0; j < logical_warp_size; j++) { - auto idx = i * logical_warp_size + j; + auto idx = i * logical_warp_size + j; expected[idx] = static_cast(value); } } @@ -445,7 +464,7 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, AllReduceSumValid) common::device_ptr device_output(output.size()); // Launching kernel - if (current_device_warp_size == ws32) + if(current_device_warp_size == ws32) { hipLaunchKernelGGL( HIP_KERNEL_NAME(warp_allreduce_sum_kernel), @@ -457,7 +476,7 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, AllReduceSumValid) device_output.get(), valid); } - else if (current_device_warp_size == ws64) + else if(current_device_warp_size == ws64) { hipLaunchKernelGGL( HIP_KERNEL_NAME(warp_allreduce_sum_kernel), @@ -478,7 +497,6 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, AllReduceSumValid) test_utils::assert_near(output, expected, test_utils::precision * valid); } - } typed_test_def(RocprimWarpReduceTests, name_suffix, ReduceCustomStruct) @@ -491,7 +509,7 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, ReduceCustomStruct) using T = common::custom_type; using acc_type = typename test_utils::select_plus_operator_host::acc_type; - // logical warp side for warp primitive, execution warp size is always rocprim::warp_size() + // logical warp size for warp primitive, execution warp size is always rocprim::warp_size() static constexpr size_t logical_warp_size = TestFixture::params::warp_size; // The different warp sizes @@ -499,36 +517,41 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, ReduceCustomStruct) static constexpr size_t ws64 = size_t(ROCPRIM_WARP_SIZE_64); // Block size of warp size 32 - static constexpr size_t block_size_ws32 = - rocprim::detail::is_power_of_two(logical_warp_size) - ? rocprim::max(ws32, logical_warp_size * 4) - : rocprim::max((ws32/logical_warp_size), 1) * logical_warp_size; + static constexpr size_t block_size_ws32 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws32, logical_warp_size * 4) + : rocprim::max((ws32 / logical_warp_size), 1) * logical_warp_size; // Block size of warp size 64 - static constexpr size_t block_size_ws64 = - rocprim::detail::is_power_of_two(logical_warp_size) - ? rocprim::max(ws64, logical_warp_size * 4) - : rocprim::max((ws64/logical_warp_size), 1) * logical_warp_size; + static constexpr size_t block_size_ws64 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws64, logical_warp_size * 4) + : rocprim::max((ws64 / logical_warp_size), 1) * logical_warp_size; unsigned int current_device_warp_size; HIP_CHECK(::rocprim::host_warp_size(device_id, current_device_warp_size)); const size_t block_size = current_device_warp_size == ws32 ? block_size_ws32 : block_size_ws64; static constexpr unsigned int grid_size = 4; - const size_t size = block_size * grid_size; + const size_t size = block_size * grid_size; // Check if warp size is supported - if( (logical_warp_size > current_device_warp_size) || - (current_device_warp_size != ws32 && current_device_warp_size != ws64) ) // Only WarpSize 32 and 64 is supported + if((logical_warp_size > current_device_warp_size) + || (current_device_warp_size != ws32 + && current_device_warp_size != ws64)) // Only VirtualWaveSize 32 and 64 is supported { - printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: %u. Skipping test\n", - logical_warp_size, block_size, current_device_warp_size); + printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: " + "%u. Skipping test\n", + logical_warp_size, + block_size, + current_device_warp_size); GTEST_SKIP(); } for(size_t seed_index = 0; seed_index < number_of_runs; seed_index++) { - unsigned int seed_value = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); // Generate data @@ -564,7 +587,7 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, ReduceCustomStruct) common::device_ptr device_output(output.size()); // Launching kernel - if (current_device_warp_size == ws32) + if(current_device_warp_size == ws32) { hipLaunchKernelGGL( HIP_KERNEL_NAME(warp_reduce_sum_kernel), @@ -575,7 +598,7 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, ReduceCustomStruct) device_input.get(), device_output.get()); } - else if (current_device_warp_size == ws64) + else if(current_device_warp_size == ws64) { hipLaunchKernelGGL( HIP_KERNEL_NAME(warp_reduce_sum_kernel), @@ -597,7 +620,6 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, ReduceCustomStruct) expected, test_utils::precision * logical_warp_size); } - } typed_test_def(RocprimWarpReduceTests, name_suffix, HeadSegmentedReduceSum) @@ -606,14 +628,14 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, HeadSegmentedReduceSum) SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); HIP_CHECK(hipSetDevice(device_id)); - // logical warp side for warp primitive, execution warp size is always rocprim::warp_size() + // logical warp size for warp primitive, execution warp size is always rocprim::warp_size() using T = typename TestFixture::params::type; - // for bfloat16 and half we use double for host-side accumulation + // for bfloat16 and half we use double for host-side accumulation using binary_op_type_host = typename test_utils::select_plus_operator_host::type; binary_op_type_host binary_op_host; using acc_type = typename test_utils::select_plus_operator_host::acc_type; - using flag_type = unsigned char; + using flag_type = unsigned char; static constexpr size_t logical_warp_size = TestFixture::params::warp_size; // The different warp sizes @@ -621,42 +643,48 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, HeadSegmentedReduceSum) static constexpr size_t ws64 = size_t(ROCPRIM_WARP_SIZE_64); // Block size of warp size 32 - static constexpr size_t block_size_ws32 = - rocprim::detail::is_power_of_two(logical_warp_size) - ? rocprim::max(ws32, logical_warp_size * 4) - : rocprim::max((ws32/logical_warp_size), 1) * logical_warp_size; + static constexpr size_t block_size_ws32 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws32, logical_warp_size * 4) + : rocprim::max((ws32 / logical_warp_size), 1) * logical_warp_size; // Block size of warp size 64 - static constexpr size_t block_size_ws64 = - rocprim::detail::is_power_of_two(logical_warp_size) - ? rocprim::max(ws64, logical_warp_size * 4) - : rocprim::max((ws64/logical_warp_size), 1) * logical_warp_size; + static constexpr size_t block_size_ws64 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws64, logical_warp_size * 4) + : rocprim::max((ws64 / logical_warp_size), 1) * logical_warp_size; unsigned int current_device_warp_size; HIP_CHECK(::rocprim::host_warp_size(device_id, current_device_warp_size)); const size_t block_size = current_device_warp_size == ws32 ? block_size_ws32 : block_size_ws64; static constexpr unsigned int grid_size = 4; - const size_t size = block_size * grid_size; + const size_t size = block_size * grid_size; // Check if warp size is supported - if( (logical_warp_size > current_device_warp_size) || - (current_device_warp_size != ws32 && current_device_warp_size != ws64) ) // Only WarpSize 32 and 64 is supported + if((logical_warp_size > current_device_warp_size) + || (current_device_warp_size != ws32 + && current_device_warp_size != ws64)) // Only VirtualWaveSize 32 and 64 is supported { - printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: %u. Skipping test\n", - logical_warp_size, block_size, current_device_warp_size); + printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: " + "%u. Skipping test\n", + logical_warp_size, + block_size, + current_device_warp_size); GTEST_SKIP(); } for(size_t seed_index = 0; seed_index < number_of_runs; seed_index++) { - unsigned int seed_value = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); // Generate data std::vector input = test_utils::get_random_data_wrapped(size, 1, 10, seed_value); - std::vector flags = test_utils::get_random_data01(size, 0.25f, seed_value); - for(size_t i = 0; i < flags.size(); i+= logical_warp_size) + std::vector flags + = test_utils::get_random_data01(size, 0.25f, seed_value); + for(size_t i = 0; i < flags.size(); i += logical_warp_size) { flags[i] = 1; } @@ -668,15 +696,15 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, HeadSegmentedReduceSum) // Calculate expected results on host std::vector expected(output.size()); - size_t segment_head_index = 0; - acc_type reduction(input[0]); + size_t segment_head_index = 0; + acc_type reduction(input[0]); for(size_t i = 0; i < output.size(); i++) { - if(i%logical_warp_size == 0 || flags[i]) + if(i % logical_warp_size == 0 || flags[i]) { expected[segment_head_index] = static_cast(reduction); - segment_head_index = i; - reduction = input[i]; + segment_head_index = i; + reduction = input[i]; } else { @@ -686,7 +714,7 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, HeadSegmentedReduceSum) expected[segment_head_index] = static_cast(reduction); // Launching kernel - if (current_device_warp_size == ws32) + if(current_device_warp_size == ws32) { hipLaunchKernelGGL( HIP_KERNEL_NAME(head_segmented_warp_reduce_kernel * logical_warp_size); } - } typed_test_def(RocprimWarpReduceTests, name_suffix, TailSegmentedReduceSum) @@ -746,14 +773,14 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, TailSegmentedReduceSum) SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); HIP_CHECK(hipSetDevice(device_id)); - // logical warp side for warp primitive, execution warp size is always rocprim::warp_size() + // logical warp size for warp primitive, execution warp size is always rocprim::warp_size() using T = typename TestFixture::params::type; - // for bfloat16 and half we use double for host-side accumulation + // for bfloat16 and half we use double for host-side accumulation using binary_op_type_host = typename test_utils::select_plus_operator_host::type; binary_op_type_host binary_op_host; using acc_type = typename test_utils::select_plus_operator_host::acc_type; - using flag_type = unsigned char; + using flag_type = unsigned char; static constexpr size_t logical_warp_size = TestFixture::params::warp_size; // The different warp sizes @@ -761,42 +788,48 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, TailSegmentedReduceSum) static constexpr size_t ws64 = size_t(ROCPRIM_WARP_SIZE_64); // Block size of warp size 32 - static constexpr size_t block_size_ws32 = - rocprim::detail::is_power_of_two(logical_warp_size) - ? rocprim::max(ws32, logical_warp_size * 4) - : rocprim::max((ws32/logical_warp_size), 1) * logical_warp_size; + static constexpr size_t block_size_ws32 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws32, logical_warp_size * 4) + : rocprim::max((ws32 / logical_warp_size), 1) * logical_warp_size; // Block size of warp size 64 - static constexpr size_t block_size_ws64 = - rocprim::detail::is_power_of_two(logical_warp_size) - ? rocprim::max(ws64, logical_warp_size * 4) - : rocprim::max((ws64/logical_warp_size), 1) * logical_warp_size; + static constexpr size_t block_size_ws64 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws64, logical_warp_size * 4) + : rocprim::max((ws64 / logical_warp_size), 1) * logical_warp_size; unsigned int current_device_warp_size; HIP_CHECK(::rocprim::host_warp_size(device_id, current_device_warp_size)); const size_t block_size = current_device_warp_size == ws32 ? block_size_ws32 : block_size_ws64; static constexpr unsigned int grid_size = 4; - const size_t size = block_size * grid_size; + const size_t size = block_size * grid_size; // Check if warp size is supported - if( (logical_warp_size > current_device_warp_size) || - (current_device_warp_size != ws32 && current_device_warp_size != ws64) ) // Only WarpSize 32 and 64 is supported + if((logical_warp_size > current_device_warp_size) + || (current_device_warp_size != ws32 + && current_device_warp_size != ws64)) // Only VirtualWaveSize 32 and 64 is supported { - printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: %u. Skipping test\n", - logical_warp_size, block_size, current_device_warp_size); + printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: " + "%u. Skipping test\n", + logical_warp_size, + block_size, + current_device_warp_size); GTEST_SKIP(); } for(size_t seed_index = 0; seed_index < number_of_runs; seed_index++) { - unsigned int seed_value = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); // Generate data std::vector input = test_utils::get_random_data_wrapped(size, 1, 10, seed_value); - std::vector flags = test_utils::get_random_data01(size, 0.25f, seed_value); - for(size_t i = logical_warp_size - 1; i < flags.size(); i+= logical_warp_size) + std::vector flags + = test_utils::get_random_data01(size, 0.25f, seed_value); + for(size_t i = logical_warp_size - 1; i < flags.size(); i += logical_warp_size) { flags[i] = 1; } @@ -807,10 +840,10 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, TailSegmentedReduceSum) common::device_ptr device_output(output.size()); // Calculate expected results on host - std::vector expected(output.size()); + std::vector expected(output.size()); std::vector segment_indexes; - size_t segment_index = 0; - acc_type reduction; + size_t segment_index = 0; + acc_type reduction; for(size_t i = 0; i < output.size(); i++) { // single value segments @@ -822,8 +855,8 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, TailSegmentedReduceSum) else { segment_index = i; - reduction = input[i]; - auto next = i + 1; + reduction = input[i]; + auto next = i + 1; while(next < output.size() && !flags[next]) { reduction = binary_op_host(input[next], reduction); @@ -837,7 +870,7 @@ typed_test_def(RocprimWarpReduceTests, name_suffix, TailSegmentedReduceSum) } // Launching kernel - if (current_device_warp_size == ws32) + if(current_device_warp_size == ws32) { hipLaunchKernelGGL( HIP_KERNEL_NAME(tail_segmented_warp_reduce_kernel expected_segment(segment_indexes.size()); for(size_t i = 0; i < segment_indexes.size(); i++) { - auto index = segment_indexes[i]; - output_segment[i] = output[index]; + auto index = segment_indexes[i]; + output_segment[i] = output[index]; expected_segment[i] = expected[index]; } test_utils::assert_near(output_segment, expected_segment, test_utils::precision * logical_warp_size); } - } diff --git a/test/rocprim/test_warp_reduce.kernels.hpp b/test/rocprim/test_warp_reduce.kernels.hpp index e1a8ee4b5..521d27980 100644 --- a/test/rocprim/test_warp_reduce.kernels.hpp +++ b/test/rocprim/test_warp_reduce.kernels.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2019-2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2019-2025 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -22,119 +22,131 @@ #define TEST_WARP_REDUCE_KERNELS_HPP_ template -__global__ -__launch_bounds__(BlockSize) +__global__ __launch_bounds__(BlockSize) void warp_reduce_sum_kernel(T* device_input, T* device_output) { - static constexpr unsigned int warps_no = BlockSize / LogicalWarpSize; - const unsigned int warp_id = rocprim::detail::logical_warp_id(); - unsigned int index = threadIdx.x + (blockIdx.x * blockDim.x); + if constexpr(LogicalWarpSize <= rocprim::arch::wavefront::max_size()) + { + static constexpr unsigned int warps_no = BlockSize / LogicalWarpSize; + const unsigned int warp_id = rocprim::detail::logical_warp_id(); + unsigned int index = threadIdx.x + (blockIdx.x * blockDim.x); - T value = device_input[index]; + T value = device_input[index]; - using wreduce_t = rocprim::warp_reduce; - __shared__ typename wreduce_t::storage_type storage[warps_no]; - wreduce_t().reduce(value, value, storage[warp_id]); + using wreduce_t = rocprim::warp_reduce; + __shared__ typename wreduce_t::storage_type storage[warps_no]; + wreduce_t().reduce(value, value, storage[warp_id]); - if(threadIdx.x%LogicalWarpSize == 0) - { - device_output[index/LogicalWarpSize] = value; + if(threadIdx.x % LogicalWarpSize == 0) + { + device_output[index / LogicalWarpSize] = value; + } } } template -__global__ -__launch_bounds__(BlockSize) +__global__ __launch_bounds__(BlockSize) void warp_allreduce_sum_kernel(T* device_input, T* device_output) { - static constexpr unsigned int warps_no = BlockSize / LogicalWarpSize; - const unsigned int warp_id = rocprim::detail::logical_warp_id(); - unsigned int index = threadIdx.x + (blockIdx.x * blockDim.x); + if constexpr(LogicalWarpSize <= rocprim::arch::wavefront::max_size()) + { + static constexpr unsigned int warps_no = BlockSize / LogicalWarpSize; + const unsigned int warp_id = rocprim::detail::logical_warp_id(); + unsigned int index = threadIdx.x + (blockIdx.x * blockDim.x); - T value = device_input[index]; + T value = device_input[index]; - using wreduce_t = rocprim::warp_reduce; - __shared__ typename wreduce_t::storage_type storage[warps_no]; - wreduce_t().reduce(value, value, storage[warp_id]); + using wreduce_t = rocprim::warp_reduce; + __shared__ typename wreduce_t::storage_type storage[warps_no]; + wreduce_t().reduce(value, value, storage[warp_id]); - device_output[index] = value; + device_output[index] = value; + } } template -__global__ -__launch_bounds__(BlockSize) +__global__ __launch_bounds__(BlockSize) void warp_reduce_sum_kernel(T* device_input, T* device_output, size_t valid) { - static constexpr unsigned int warps_no = BlockSize / LogicalWarpSize; - const unsigned int warp_id = rocprim::detail::logical_warp_id(); - unsigned int index = threadIdx.x + (blockIdx.x * blockDim.x); + if constexpr(LogicalWarpSize <= rocprim::arch::wavefront::max_size()) + { + static constexpr unsigned int warps_no = BlockSize / LogicalWarpSize; + const unsigned int warp_id = rocprim::detail::logical_warp_id(); + unsigned int index = threadIdx.x + (blockIdx.x * blockDim.x); - T value = device_input[index]; + T value = device_input[index]; - using wreduce_t = rocprim::warp_reduce; - __shared__ typename wreduce_t::storage_type storage[warps_no]; - wreduce_t().reduce(value, value, valid, storage[warp_id]); + using wreduce_t = rocprim::warp_reduce; + __shared__ typename wreduce_t::storage_type storage[warps_no]; + wreduce_t().reduce(value, value, valid, storage[warp_id]); - if(threadIdx.x%LogicalWarpSize == 0) - { - device_output[index/LogicalWarpSize] = value; + if(threadIdx.x % LogicalWarpSize == 0) + { + device_output[index / LogicalWarpSize] = value; + } } } template -__global__ -__launch_bounds__(BlockSize) +__global__ __launch_bounds__(BlockSize) void warp_allreduce_sum_kernel(T* device_input, T* device_output, size_t valid) { - constexpr unsigned int warps_no = BlockSize / LogicalWarpSize; - const unsigned int warp_id = rocprim::detail::logical_warp_id(); - unsigned int index = threadIdx.x + (blockIdx.x * blockDim.x); + if constexpr(LogicalWarpSize <= rocprim::arch::wavefront::max_size()) + { + constexpr unsigned int warps_no = BlockSize / LogicalWarpSize; + const unsigned int warp_id = rocprim::detail::logical_warp_id(); + unsigned int index = threadIdx.x + (blockIdx.x * blockDim.x); - T value = device_input[index]; + T value = device_input[index]; - using wreduce_t = rocprim::warp_reduce; - __shared__ typename wreduce_t::storage_type storage[warps_no]; - wreduce_t().reduce(value, value, valid, storage[warp_id]); + using wreduce_t = rocprim::warp_reduce; + __shared__ typename wreduce_t::storage_type storage[warps_no]; + wreduce_t().reduce(value, value, valid, storage[warp_id]); - device_output[index] = value; + device_output[index] = value; + } } template -__global__ -__launch_bounds__(BlockSize) +__global__ __launch_bounds__(BlockSize) void head_segmented_warp_reduce_kernel(T* input, Flag* flags, T* output) { - constexpr unsigned int warps_no = BlockSize / LogicalWarpSize; - const unsigned int warp_id = rocprim::detail::logical_warp_id(); - unsigned int index = threadIdx.x + (blockIdx.x * blockDim.x); + if constexpr(LogicalWarpSize <= rocprim::arch::wavefront::max_size()) + { + constexpr unsigned int warps_no = BlockSize / LogicalWarpSize; + const unsigned int warp_id = rocprim::detail::logical_warp_id(); + unsigned int index = threadIdx.x + (blockIdx.x * blockDim.x); - T value = input[index]; - auto flag = flags[index]; + T value = input[index]; + auto flag = flags[index]; - using wreduce_t = rocprim::warp_reduce; - __shared__ typename wreduce_t::storage_type storage[warps_no]; - wreduce_t().head_segmented_reduce(value, value, flag, storage[warp_id]); + using wreduce_t = rocprim::warp_reduce; + __shared__ typename wreduce_t::storage_type storage[warps_no]; + wreduce_t().head_segmented_reduce(value, value, flag, storage[warp_id]); - output[index] = value; + output[index] = value; + } } template -__global__ -__launch_bounds__(BlockSize) +__global__ __launch_bounds__(BlockSize) void tail_segmented_warp_reduce_kernel(T* input, Flag* flags, T* output) { - constexpr unsigned int warps_no = BlockSize / LogicalWarpSize; - const unsigned int warp_id = rocprim::detail::logical_warp_id(); - unsigned int index = threadIdx.x + (blockIdx.x * blockDim.x); + if constexpr(LogicalWarpSize <= rocprim::arch::wavefront::max_size()) + { + constexpr unsigned int warps_no = BlockSize / LogicalWarpSize; + const unsigned int warp_id = rocprim::detail::logical_warp_id(); + unsigned int index = threadIdx.x + (blockIdx.x * blockDim.x); - T value = input[index]; - auto flag = flags[index]; + T value = input[index]; + auto flag = flags[index]; - using wreduce_t = rocprim::warp_reduce; - __shared__ typename wreduce_t::storage_type storage[warps_no]; - wreduce_t().tail_segmented_reduce(value, value, flag, storage[warp_id]); + using wreduce_t = rocprim::warp_reduce; + __shared__ typename wreduce_t::storage_type storage[warps_no]; + wreduce_t().tail_segmented_reduce(value, value, flag, storage[warp_id]); - output[index] = value; + output[index] = value; + } } #endif // TEST_WARP_REDUCE_KERNELS_HPP_ diff --git a/test/rocprim/test_warp_scan.cpp b/test/rocprim/test_warp_scan.cpp index ced4932a4..5170f2b73 100644 --- a/test/rocprim/test_warp_scan.cpp +++ b/test/rocprim/test_warp_scan.cpp @@ -41,7 +41,9 @@ struct Integral; #define warp_params WarpParamsIntegral #define name_suffix Integral -#include "test_warp_scan.hpp" +#if !_CLANGD + #include "test_warp_scan.hpp" +#endif #undef suite_name #undef warp_params @@ -52,4 +54,6 @@ struct Floating; #define warp_params WarpParamsFloating #define name_suffix Floating -#include "test_warp_scan.hpp" +#if !_CLANGD + #include "test_warp_scan.hpp" +#endif diff --git a/test/rocprim/test_warp_scan.hpp b/test/rocprim/test_warp_scan.hpp index e4bb9be1b..0ac49a501 100644 --- a/test/rocprim/test_warp_scan.hpp +++ b/test/rocprim/test_warp_scan.hpp @@ -37,9 +37,13 @@ #include #include +#if _CLANGD + #include "test_warp_scan.cpp" +#endif + test_suite_type_def(suite_name, name_suffix) -typed_test_suite_def(RocprimWarpScanTests, name_suffix, warp_params); + typed_test_suite_def(RocprimWarpScanTests, name_suffix, warp_params); typed_test_def(RocprimWarpScanTests, name_suffix, InclusiveScan) { @@ -61,36 +65,42 @@ typed_test_def(RocprimWarpScanTests, name_suffix, InclusiveScan) static constexpr size_t ws64 = size_t(ROCPRIM_WARP_SIZE_64); // Block size of warp size 32 - static constexpr size_t block_size_ws32 = - rocprim::detail::is_power_of_two(logical_warp_size) - ? rocprim::max(ws32, logical_warp_size * 4) - : rocprim::max((ws32/logical_warp_size), 1) * logical_warp_size; + static constexpr size_t block_size_ws32 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws32, logical_warp_size * 4) + : rocprim::max((ws32 / logical_warp_size), 1) * logical_warp_size; // Block size of warp size 64 - static constexpr size_t block_size_ws64 = - rocprim::detail::is_power_of_two(logical_warp_size) - ? rocprim::max(ws64, logical_warp_size * 4) - : rocprim::max((ws64/logical_warp_size), 1) * logical_warp_size; + static constexpr size_t block_size_ws64 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws64, logical_warp_size * 4) + : rocprim::max((ws64 / logical_warp_size), 1) * logical_warp_size; unsigned int current_device_warp_size; HIP_CHECK(::rocprim::host_warp_size(device_id, current_device_warp_size)); const size_t block_size = current_device_warp_size == ws32 ? block_size_ws32 : block_size_ws64; const unsigned int grid_size = 4; + const size_t size = block_size * grid_size; // Check if warp size is supported - if( (logical_warp_size > current_device_warp_size) || - (current_device_warp_size != ws32 && current_device_warp_size != ws64) ) // Only WarpSize 32 and 64 is supported + if((logical_warp_size > current_device_warp_size) + || (current_device_warp_size != ws32 + && current_device_warp_size != ws64)) // Only VirtualWaveSize 32 and 64 is supported { - printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: %u. Skipping test\n", - logical_warp_size, block_size, current_device_warp_size); + printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: " + "%u. Skipping test\n", + logical_warp_size, + block_size, + current_device_warp_size); GTEST_SKIP(); } for(size_t seed_index = 0; seed_index < number_of_runs; seed_index++) { - unsigned int seed_value = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); // Generate data @@ -104,8 +114,8 @@ typed_test_def(RocprimWarpScanTests, name_suffix, InclusiveScan) acc_type accumulator(0); for(size_t j = 0; j < logical_warp_size; j++) { - auto idx = i * logical_warp_size + j; - accumulator = binary_op_host(input[idx], accumulator); + auto idx = i * logical_warp_size + j; + accumulator = binary_op_host(input[idx], accumulator); expected[idx] = static_cast(accumulator); } } @@ -115,7 +125,7 @@ typed_test_def(RocprimWarpScanTests, name_suffix, InclusiveScan) common::device_ptr device_output(output.size()); // Launching kernel - if (current_device_warp_size == ws32) + if(current_device_warp_size == ws32) { hipLaunchKernelGGL( HIP_KERNEL_NAME(warp_inclusive_scan_kernel), @@ -126,7 +136,7 @@ typed_test_def(RocprimWarpScanTests, name_suffix, InclusiveScan) device_input.get(), device_output.get()); } - else if (current_device_warp_size == ws64) + else if(current_device_warp_size == ws64) { hipLaunchKernelGGL( HIP_KERNEL_NAME(warp_inclusive_scan_kernel), @@ -147,7 +157,6 @@ typed_test_def(RocprimWarpScanTests, name_suffix, InclusiveScan) // Validating results test_utils::assert_near(output, expected, test_utils::precision * logical_warp_size); } - } typed_test_def(RocprimWarpScanTests, name_suffix, InclusiveScanInitialValue) @@ -186,12 +195,13 @@ typed_test_def(RocprimWarpScanTests, name_suffix, InclusiveScanInitialValue) const size_t block_size = current_device_warp_size == ws32 ? block_size_ws32 : block_size_ws64; const unsigned int grid_size = 4; - const size_t size = block_size * grid_size; + + const size_t size = block_size * grid_size; // Check if warp size is supported if((logical_warp_size > current_device_warp_size) || (current_device_warp_size != ws32 - && current_device_warp_size != ws64)) // Only WarpSize 32 and 64 is supported + && current_device_warp_size != ws64)) // Only VirtualWaveSize 32 and 64 is supported { printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: " "%u. Skipping test\n", @@ -291,36 +301,42 @@ typed_test_def(RocprimWarpScanTests, name_suffix, InclusiveScanReduce) static constexpr size_t ws64 = size_t(ROCPRIM_WARP_SIZE_64); // Block size of warp size 32 - static constexpr size_t block_size_ws32 = - rocprim::detail::is_power_of_two(logical_warp_size) - ? rocprim::max(ws32, logical_warp_size * 4) - : rocprim::max((ws32/logical_warp_size), 1) * logical_warp_size; + static constexpr size_t block_size_ws32 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws32, logical_warp_size * 4) + : rocprim::max((ws32 / logical_warp_size), 1) * logical_warp_size; // Block size of warp size 64 - static constexpr size_t block_size_ws64 = - rocprim::detail::is_power_of_two(logical_warp_size) - ? rocprim::max(ws64, logical_warp_size * 4) - : rocprim::max((ws64/logical_warp_size), 1) * logical_warp_size; + static constexpr size_t block_size_ws64 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws64, logical_warp_size * 4) + : rocprim::max((ws64 / logical_warp_size), 1) * logical_warp_size; unsigned int current_device_warp_size; HIP_CHECK(::rocprim::host_warp_size(device_id, current_device_warp_size)); const size_t block_size = current_device_warp_size == ws32 ? block_size_ws32 : block_size_ws64; const unsigned int grid_size = 4; + const size_t size = block_size * grid_size; // Check if warp size is supported - if( (logical_warp_size > current_device_warp_size) || - (current_device_warp_size != ws32 && current_device_warp_size != ws64) ) // Only WarpSize 32 and 64 is supported + if((logical_warp_size > current_device_warp_size) + || (current_device_warp_size != ws32 + && current_device_warp_size != ws64)) // Only VirtualWaveSize 32 and 64 is supported { - printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: %u. Skipping test\n", - logical_warp_size, block_size, current_device_warp_size); + printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: " + "%u. Skipping test\n", + logical_warp_size, + block_size, + current_device_warp_size); GTEST_SKIP(); } for(size_t seed_index = 0; seed_index < number_of_runs; seed_index++) { - unsigned int seed_value = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); // Generate data @@ -336,11 +352,11 @@ typed_test_def(RocprimWarpScanTests, name_suffix, InclusiveScanReduce) acc_type accumulator(0); for(size_t j = 0; j < logical_warp_size; j++) { - auto idx = i * logical_warp_size + j; - accumulator = binary_op_host(input[idx],accumulator); + auto idx = i * logical_warp_size + j; + accumulator = binary_op_host(input[idx], accumulator); expected[idx] = static_cast(accumulator); } - expected_reductions[i] = expected[(i+1) * logical_warp_size - 1]; + expected_reductions[i] = expected[(i + 1) * logical_warp_size - 1]; } // Writing to device memory @@ -349,7 +365,7 @@ typed_test_def(RocprimWarpScanTests, name_suffix, InclusiveScanReduce) common::device_ptr device_output_reductions(output_reductions.size()); // Launching kernel - if (current_device_warp_size == ws32) + if(current_device_warp_size == ws32) { hipLaunchKernelGGL( HIP_KERNEL_NAME( @@ -389,7 +405,6 @@ typed_test_def(RocprimWarpScanTests, name_suffix, InclusiveScanReduce) expected_reductions, test_utils::precision * logical_warp_size); } - } typed_test_def(RocprimWarpScanTests, name_suffix, InclusiveScanReduceInitialValue) @@ -428,12 +443,13 @@ typed_test_def(RocprimWarpScanTests, name_suffix, InclusiveScanReduceInitialValu const size_t block_size = current_device_warp_size == ws32 ? block_size_ws32 : block_size_ws64; const unsigned int grid_size = 4; - const size_t size = block_size * grid_size; + + const size_t size = block_size * grid_size; // Check if warp size is supported if((logical_warp_size > current_device_warp_size) || (current_device_warp_size != ws32 - && current_device_warp_size != ws64)) // Only WarpSize 32 and 64 is supported + && current_device_warp_size != ws64)) // Only VirtualWaveSize 32 and 64 is supported { printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: " "%u. Skipping test\n", @@ -461,14 +477,19 @@ typed_test_def(RocprimWarpScanTests, name_suffix, InclusiveScanReduceInitialValu // Calculate expected results on host for(size_t i = 0; i < output.size() / logical_warp_size; i++) { - acc_type accumulator(initial_value); + acc_type accumulator = acc_type(initial_value); + acc_type reduction = input[i * logical_warp_size]; for(size_t j = 0; j < logical_warp_size; j++) { auto idx = i * logical_warp_size + j; accumulator = binary_op_host(input[idx], accumulator); expected[idx] = static_cast(accumulator); + if(j > 0) + { + reduction = binary_op_host(input[idx], reduction); + } } - expected_reductions[i] = expected[(i + 1) * logical_warp_size - 1]; + expected_reductions[i] = static_cast(reduction); } // Writing to device memory @@ -543,43 +564,49 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveScan) static constexpr size_t ws64 = size_t(ROCPRIM_WARP_SIZE_64); // Block size of warp size 32 - static constexpr size_t block_size_ws32 = - rocprim::detail::is_power_of_two(logical_warp_size) - ? rocprim::max(ws32, logical_warp_size * 4) - : rocprim::max((ws32/logical_warp_size), 1) * logical_warp_size; + static constexpr size_t block_size_ws32 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws32, logical_warp_size * 4) + : rocprim::max((ws32 / logical_warp_size), 1) * logical_warp_size; // Block size of warp size 64 - static constexpr size_t block_size_ws64 = - rocprim::detail::is_power_of_two(logical_warp_size) - ? rocprim::max(ws64, logical_warp_size * 4) - : rocprim::max((ws64/logical_warp_size), 1) * logical_warp_size; + static constexpr size_t block_size_ws64 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws64, logical_warp_size * 4) + : rocprim::max((ws64 / logical_warp_size), 1) * logical_warp_size; unsigned int current_device_warp_size; HIP_CHECK(::rocprim::host_warp_size(device_id, current_device_warp_size)); const size_t block_size = current_device_warp_size == ws32 ? block_size_ws32 : block_size_ws64; const unsigned int grid_size = 4; + const size_t size = block_size * grid_size; // Check if warp size is supported - if( (logical_warp_size > current_device_warp_size) || - (current_device_warp_size != ws32 && current_device_warp_size != ws64) ) // Only WarpSize 32 and 64 is supported + if((logical_warp_size > current_device_warp_size) + || (current_device_warp_size != ws32 + && current_device_warp_size != ws64)) // Only VirtualWaveSize 32 and 64 is supported { - printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: %u. Skipping test\n", - logical_warp_size, block_size, current_device_warp_size); + printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: " + "%u. Skipping test\n", + logical_warp_size, + block_size, + current_device_warp_size); GTEST_SKIP(); } for(size_t seed_index = 0; seed_index < number_of_runs; seed_index++) { - unsigned int seed_value = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); // Generate data std::vector input = test_utils::get_random_data_wrapped(size, 2, 50, seed_value); std::vector output(size); std::vector expected(input.size(), T(0)); - const T init = test_utils::get_random_value(0, 100, seed_value); + const T init = test_utils::get_random_value(0, 100, seed_value); // Calculate expected results on host for(size_t i = 0; i < input.size() / logical_warp_size; i++) @@ -588,8 +615,8 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveScan) expected[i * logical_warp_size] = init; for(size_t j = 1; j < logical_warp_size; j++) { - auto idx = i * logical_warp_size + j; - accumulator = binary_op_host(input[idx-1], accumulator); + auto idx = i * logical_warp_size + j; + accumulator = binary_op_host(input[idx - 1], accumulator); expected[idx] = static_cast(accumulator); } } @@ -599,7 +626,7 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveScan) common::device_ptr device_output(output.size()); // Launching kernel - if (current_device_warp_size == ws32) + if(current_device_warp_size == ws32) { hipLaunchKernelGGL( HIP_KERNEL_NAME(warp_exclusive_scan_kernel), @@ -611,7 +638,7 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveScan) device_output.get(), init); } - else if (current_device_warp_size == ws64) + else if(current_device_warp_size == ws64) { hipLaunchKernelGGL( HIP_KERNEL_NAME(warp_exclusive_scan_kernel), @@ -633,7 +660,6 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveScan) // Validating results test_utils::assert_near(output, expected, test_utils::precision * logical_warp_size); } - } typed_test_def(RocprimWarpScanTests, name_suffix, Broadcast) @@ -672,12 +698,13 @@ typed_test_def(RocprimWarpScanTests, name_suffix, Broadcast) const size_t block_size = current_device_warp_size == ws32 ? block_size_ws32 : block_size_ws64; const unsigned int grid_size = 4; - const size_t size = block_size * grid_size; + + const size_t size = block_size * grid_size; // Check if warp size is supported if((logical_warp_size > current_device_warp_size) || (current_device_warp_size != ws32 - && current_device_warp_size != ws64)) // Only WarpSize 32 and 64 is supported + && current_device_warp_size != ws64)) // Only VirtualWaveSize 32 and 64 is supported { GTEST_SKIP() << "Unsupported test warp size/computed block size: " << logical_warp_size << "/" << block_size @@ -782,12 +809,13 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveScanWoInit) const size_t block_size = current_device_warp_size == ws32 ? block_size_ws32 : block_size_ws64; const unsigned int grid_size = 4; - const size_t size = block_size * grid_size; + + const size_t size = block_size * grid_size; // Check if warp size is supported if((logical_warp_size > current_device_warp_size) || (current_device_warp_size != ws32 - && current_device_warp_size != ws64)) // Only WarpSize 32 and 64 is supported + && current_device_warp_size != ws64)) // Only VirtualWaveSize 32 and 64 is supported { printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: " "%d. Skipping test\n", @@ -894,36 +922,42 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveReduceScan) static constexpr size_t ws64 = size_t(ROCPRIM_WARP_SIZE_64); // Block size of warp size 32 - static constexpr size_t block_size_ws32 = - rocprim::detail::is_power_of_two(logical_warp_size) - ? rocprim::max(ws32, logical_warp_size * 4) - : rocprim::max((ws32/logical_warp_size), 1) * logical_warp_size; + static constexpr size_t block_size_ws32 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws32, logical_warp_size * 4) + : rocprim::max((ws32 / logical_warp_size), 1) * logical_warp_size; // Block size of warp size 64 - static constexpr size_t block_size_ws64 = - rocprim::detail::is_power_of_two(logical_warp_size) - ? rocprim::max(ws64, logical_warp_size * 4) - : rocprim::max((ws64/logical_warp_size), 1) * logical_warp_size; + static constexpr size_t block_size_ws64 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws64, logical_warp_size * 4) + : rocprim::max((ws64 / logical_warp_size), 1) * logical_warp_size; unsigned int current_device_warp_size; HIP_CHECK(::rocprim::host_warp_size(device_id, current_device_warp_size)); const size_t block_size = current_device_warp_size == ws32 ? block_size_ws32 : block_size_ws64; const unsigned int grid_size = 4; + const size_t size = block_size * grid_size; // Check if warp size is supported - if( (logical_warp_size > current_device_warp_size) || - (current_device_warp_size != ws32 && current_device_warp_size != ws64) ) // Only WarpSize 32 and 64 is supported + if((logical_warp_size > current_device_warp_size) + || (current_device_warp_size != ws32 + && current_device_warp_size != ws64)) // Only VirtualWaveSize 32 and 64 is supported { - printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: %u. Skipping test\n", - logical_warp_size, block_size, current_device_warp_size); + printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: " + "%u. Skipping test\n", + logical_warp_size, + block_size, + current_device_warp_size); GTEST_SKIP(); } for(size_t seed_index = 0; seed_index < number_of_runs; seed_index++) { - unsigned int seed_value = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); // Generate data @@ -932,7 +966,7 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveReduceScan) std::vector output_reductions(size / logical_warp_size); std::vector expected(input.size(), T(0)); std::vector expected_reductions(output_reductions.size(), T(0)); - const T init = test_utils::get_random_value(0, 100, seed_value); + const T init = test_utils::get_random_value(0, 100, seed_value); // Calculate expected results on host for(size_t i = 0; i < input.size() / logical_warp_size; i++) @@ -941,15 +975,15 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveReduceScan) expected[i * logical_warp_size] = init; for(size_t j = 1; j < logical_warp_size; j++) { - auto idx = i * logical_warp_size + j; - accumulator = binary_op_host(input[idx-1], accumulator); + auto idx = i * logical_warp_size + j; + accumulator = binary_op_host(input[idx - 1], accumulator); expected[idx] = static_cast(accumulator); } acc_type accumulator_reductions(0); for(size_t j = 0; j < logical_warp_size; j++) { - auto idx = i * logical_warp_size + j; + auto idx = i * logical_warp_size + j; accumulator_reductions = binary_op_host(input[idx], accumulator_reductions); expected_reductions[i] = static_cast(accumulator_reductions); } @@ -961,7 +995,7 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveReduceScan) common::device_ptr device_output_reductions(output_reductions.size()); // Launching kernel - if (current_device_warp_size == ws32) + if(current_device_warp_size == ws32) { hipLaunchKernelGGL( HIP_KERNEL_NAME( @@ -975,7 +1009,7 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveReduceScan) device_output_reductions.get(), init); } - else if (current_device_warp_size == ws64) + else if(current_device_warp_size == ws64) { hipLaunchKernelGGL( HIP_KERNEL_NAME( @@ -1002,7 +1036,6 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveReduceScan) expected_reductions, test_utils::precision * logical_warp_size); } - } typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveReduceScanWoInit) @@ -1041,12 +1074,13 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ExclusiveReduceScanWoInit) const size_t block_size = current_device_warp_size == ws32 ? block_size_ws32 : block_size_ws64; const unsigned int grid_size = 4; - const size_t size = block_size * grid_size; + + const size_t size = block_size * grid_size; // Check if warp size is supported if((logical_warp_size > current_device_warp_size) || (current_device_warp_size != ws32 - && current_device_warp_size != ws64)) // Only WarpSize 32 and 64 is supported + && current_device_warp_size != ws64)) // Only VirtualWaveSize 32 and 64 is supported { printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: " "%d. Skipping test\n", @@ -1170,36 +1204,42 @@ typed_test_def(RocprimWarpScanTests, name_suffix, Scan) static constexpr size_t ws64 = size_t(ROCPRIM_WARP_SIZE_64); // Block size of warp size 32 - static constexpr size_t block_size_ws32 = - rocprim::detail::is_power_of_two(logical_warp_size) - ? rocprim::max(ws32, logical_warp_size * 4) - : rocprim::max((ws32/logical_warp_size), 1) * logical_warp_size; + static constexpr size_t block_size_ws32 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws32, logical_warp_size * 4) + : rocprim::max((ws32 / logical_warp_size), 1) * logical_warp_size; // Block size of warp size 64 - static constexpr size_t block_size_ws64 = - rocprim::detail::is_power_of_two(logical_warp_size) - ? rocprim::max(ws64, logical_warp_size * 4) - : rocprim::max((ws64/logical_warp_size), 1) * logical_warp_size; + static constexpr size_t block_size_ws64 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws64, logical_warp_size * 4) + : rocprim::max((ws64 / logical_warp_size), 1) * logical_warp_size; unsigned int current_device_warp_size; HIP_CHECK(::rocprim::host_warp_size(device_id, current_device_warp_size)); const size_t block_size = current_device_warp_size == ws32 ? block_size_ws32 : block_size_ws64; const unsigned int grid_size = 4; + const size_t size = block_size * grid_size; // Check if warp size is supported - if( (logical_warp_size > current_device_warp_size) || - (current_device_warp_size != ws32 && current_device_warp_size != ws64) ) // Only WarpSize 32 and 64 is supported + if((logical_warp_size > current_device_warp_size) + || (current_device_warp_size != ws32 + && current_device_warp_size != ws64)) // Only VirtualWaveSize 32 and 64 is supported { - printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: %u. Skipping test\n", - logical_warp_size, block_size, current_device_warp_size); + printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: " + "%u. Skipping test\n", + logical_warp_size, + block_size, + current_device_warp_size); GTEST_SKIP(); } for(size_t seed_index = 0; seed_index < number_of_runs; seed_index++) { - unsigned int seed_value = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); // Generate data @@ -1208,22 +1248,22 @@ typed_test_def(RocprimWarpScanTests, name_suffix, Scan) std::vector output_exclusive(size); std::vector expected_inclusive(output_inclusive.size(), T(0)); std::vector expected_exclusive(output_exclusive.size(), T(0)); - const T init = test_utils::get_random_value(0, 100, seed_value); + const T init = test_utils::get_random_value(0, 100, seed_value); // Calculate expected results on host for(size_t i = 0; i < input.size() / logical_warp_size; i++) { acc_type accumulator_inclusive(0); - acc_type accumulator_exclusive = init; + acc_type accumulator_exclusive = init; expected_exclusive[i * logical_warp_size] = init; for(size_t j = 0; j < logical_warp_size; j++) { - auto idx = i * logical_warp_size + j; - accumulator_inclusive = binary_op_host(input[idx], accumulator_inclusive); + auto idx = i * logical_warp_size + j; + accumulator_inclusive = binary_op_host(input[idx], accumulator_inclusive); expected_inclusive[idx] = static_cast(accumulator_inclusive); if(j > 0) { - accumulator_exclusive = binary_op_host(input[idx-1], accumulator_exclusive); + accumulator_exclusive = binary_op_host(input[idx - 1], accumulator_exclusive); expected_exclusive[idx] = static_cast(accumulator_exclusive); } } @@ -1235,7 +1275,7 @@ typed_test_def(RocprimWarpScanTests, name_suffix, Scan) common::device_ptr device_exclusive_output(output_exclusive.size()); // Launching kernel - if (current_device_warp_size == ws32) + if(current_device_warp_size == ws32) { hipLaunchKernelGGL( HIP_KERNEL_NAME(warp_scan_kernel), @@ -1248,7 +1288,7 @@ typed_test_def(RocprimWarpScanTests, name_suffix, Scan) device_exclusive_output.get(), init); } - else if (current_device_warp_size == ws64) + else if(current_device_warp_size == ws64) { hipLaunchKernelGGL( HIP_KERNEL_NAME(warp_scan_kernel), @@ -1299,36 +1339,42 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ScanReduce) static constexpr size_t ws64 = size_t(ROCPRIM_WARP_SIZE_64); // Block size of warp size 32 - static constexpr size_t block_size_ws32 = - rocprim::detail::is_power_of_two(logical_warp_size) - ? rocprim::max(ws32, logical_warp_size * 4) - : rocprim::max((ws32/logical_warp_size), 1) * logical_warp_size; + static constexpr size_t block_size_ws32 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws32, logical_warp_size * 4) + : rocprim::max((ws32 / logical_warp_size), 1) * logical_warp_size; // Block size of warp size 64 - static constexpr size_t block_size_ws64 = - rocprim::detail::is_power_of_two(logical_warp_size) - ? rocprim::max(ws64, logical_warp_size * 4) - : rocprim::max((ws64/logical_warp_size), 1) * logical_warp_size; + static constexpr size_t block_size_ws64 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws64, logical_warp_size * 4) + : rocprim::max((ws64 / logical_warp_size), 1) * logical_warp_size; unsigned int current_device_warp_size; HIP_CHECK(::rocprim::host_warp_size(device_id, current_device_warp_size)); const size_t block_size = current_device_warp_size == ws32 ? block_size_ws32 : block_size_ws64; const unsigned int grid_size = 4; + const size_t size = block_size * grid_size; // Check if warp size is supported - if( (logical_warp_size > current_device_warp_size) || - (current_device_warp_size != ws32 && current_device_warp_size != ws64) ) // Only WarpSize 32 and 64 is supported + if((logical_warp_size > current_device_warp_size) + || (current_device_warp_size != ws32 + && current_device_warp_size != ws64)) // Only VirtualWaveSize 32 and 64 is supported { - printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: %u. Skipping test\n", - logical_warp_size, block_size, current_device_warp_size); + printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: " + "%u. Skipping test\n", + logical_warp_size, + block_size, + current_device_warp_size); GTEST_SKIP(); } for(size_t seed_index = 0; seed_index < number_of_runs; seed_index++) { - unsigned int seed_value = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); // Generate data @@ -1339,7 +1385,7 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ScanReduce) std::vector expected_inclusive(output_inclusive.size(), T(0)); std::vector expected_exclusive(output_exclusive.size(), T(0)); std::vector expected_reductions(output_reductions.size(), T(0)); - const T init = test_utils::get_random_value(0, 100, seed_value); + const T init = test_utils::get_random_value(0, 100, seed_value); // Calculate expected results on host for(size_t i = 0; i < input.size() / logical_warp_size; i++) @@ -1349,16 +1395,16 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ScanReduce) expected_exclusive[i * logical_warp_size] = init; for(size_t j = 0; j < logical_warp_size; j++) { - auto idx = i * logical_warp_size + j; - accumulator_inclusive = binary_op_host(input[idx], accumulator_inclusive); + auto idx = i * logical_warp_size + j; + accumulator_inclusive = binary_op_host(input[idx], accumulator_inclusive); expected_inclusive[idx] = static_cast(accumulator_inclusive); if(j > 0) { - accumulator_exclusive = binary_op_host(input[idx-1], accumulator_exclusive); + accumulator_exclusive = binary_op_host(input[idx - 1], accumulator_exclusive); expected_exclusive[idx] = static_cast(accumulator_exclusive); } } - expected_reductions[i] = expected_inclusive[(i+1) * logical_warp_size - 1]; + expected_reductions[i] = expected_inclusive[(i + 1) * logical_warp_size - 1]; } // Writing to device memory @@ -1368,7 +1414,7 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ScanReduce) common::device_ptr device_output_reductions(output_reductions.size()); // Launching kernel - if (current_device_warp_size == ws32) + if(current_device_warp_size == ws32) { hipLaunchKernelGGL( HIP_KERNEL_NAME(warp_scan_reduce_kernel), @@ -1382,7 +1428,7 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ScanReduce) device_output_reductions.get(), init); } - else if (current_device_warp_size == ws64) + else if(current_device_warp_size == ws64) { hipLaunchKernelGGL( HIP_KERNEL_NAME(warp_scan_reduce_kernel), @@ -1416,7 +1462,6 @@ typed_test_def(RocprimWarpScanTests, name_suffix, ScanReduce) expected_reductions, test_utils::precision * logical_warp_size); } - } typed_test_def(RocprimWarpScanTests, name_suffix, InclusiveScanCustomType) @@ -1437,36 +1482,42 @@ typed_test_def(RocprimWarpScanTests, name_suffix, InclusiveScanCustomType) static constexpr size_t ws64 = size_t(ROCPRIM_WARP_SIZE_64); // Block size of warp size 32 - static constexpr size_t block_size_ws32 = - rocprim::detail::is_power_of_two(logical_warp_size) - ? rocprim::max(ws32, logical_warp_size * 4) - : rocprim::max((ws32/logical_warp_size), 1) * logical_warp_size; + static constexpr size_t block_size_ws32 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws32, logical_warp_size * 4) + : rocprim::max((ws32 / logical_warp_size), 1) * logical_warp_size; // Block size of warp size 64 - static constexpr size_t block_size_ws64 = - rocprim::detail::is_power_of_two(logical_warp_size) - ? rocprim::max(ws64, logical_warp_size * 4) - : rocprim::max((ws64/logical_warp_size), 1) * logical_warp_size; + static constexpr size_t block_size_ws64 + = rocprim::detail::is_power_of_two(logical_warp_size) + ? rocprim::max(ws64, logical_warp_size * 4) + : rocprim::max((ws64 / logical_warp_size), 1) * logical_warp_size; unsigned int current_device_warp_size; HIP_CHECK(::rocprim::host_warp_size(device_id, current_device_warp_size)); const size_t block_size = current_device_warp_size == ws32 ? block_size_ws32 : block_size_ws64; const unsigned int grid_size = 4; + const size_t size = block_size * grid_size; // Check if warp size is supported - if( (logical_warp_size > current_device_warp_size) || - (current_device_warp_size != ws32 && current_device_warp_size != ws64) ) // Only WarpSize 32 and 64 is supported + if((logical_warp_size > current_device_warp_size) + || (current_device_warp_size != ws32 + && current_device_warp_size != ws64)) // Only VirtualWaveSize 32 and 64 is supported { - printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: %u. Skipping test\n", - logical_warp_size, block_size, current_device_warp_size); + printf("Unsupported test warp size/computed block size: %zu/%zu. Current device warp size: " + "%u. Skipping test\n", + logical_warp_size, + block_size, + current_device_warp_size); GTEST_SKIP(); } for(size_t seed_index = 0; seed_index < number_of_runs; seed_index++) { - unsigned int seed_value = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); // Generate data @@ -1492,7 +1543,7 @@ typed_test_def(RocprimWarpScanTests, name_suffix, InclusiveScanCustomType) common::custom_type accumulator(acc_type(0)); for(size_t j = 0; j < logical_warp_size; j++) { - auto idx = i * logical_warp_size + j; + auto idx = i * logical_warp_size + j; accumulator = static_cast>(input[idx]) + accumulator; expected[idx] = static_cast(accumulator); @@ -1504,7 +1555,7 @@ typed_test_def(RocprimWarpScanTests, name_suffix, InclusiveScanCustomType) common::device_ptr device_output(output.size()); // Launching kernel - if (current_device_warp_size == ws32) + if(current_device_warp_size == ws32) { hipLaunchKernelGGL( HIP_KERNEL_NAME(warp_inclusive_scan_kernel), @@ -1515,7 +1566,7 @@ typed_test_def(RocprimWarpScanTests, name_suffix, InclusiveScanCustomType) device_input.get(), device_output.get()); } - else if (current_device_warp_size == ws64) + else if(current_device_warp_size == ws64) { hipLaunchKernelGGL( HIP_KERNEL_NAME(warp_inclusive_scan_kernel), diff --git a/test/rocprim/test_warp_scan.kernels.hpp b/test/rocprim/test_warp_scan.kernels.hpp index 199a1a9f8..9de4049d5 100644 --- a/test/rocprim/test_warp_scan.kernels.hpp +++ b/test/rocprim/test_warp_scan.kernels.hpp @@ -24,170 +24,188 @@ #define TEST_SCAN_REDUCE_KERNELS_HPP_ template -__global__ -__launch_bounds__(BlockSize) +__global__ __launch_bounds__(BlockSize) void warp_inclusive_scan_kernel(T* device_input, T* device_output) { - constexpr unsigned int warps_no = BlockSize / LogicalWarpSize; - const unsigned int warp_id = rocprim::detail::logical_warp_id(); - unsigned int index = threadIdx.x + (blockIdx.x * blockDim.x); + if constexpr(LogicalWarpSize <= rocprim::arch::wavefront::max_size()) + { + constexpr unsigned int warps_no = BlockSize / LogicalWarpSize; + const unsigned int warp_id = rocprim::detail::logical_warp_id(); + unsigned int index = threadIdx.x + (blockIdx.x * blockDim.x); - T value = device_input[index]; + T value = device_input[index]; - using wscan_t = rocprim::warp_scan; - __shared__ typename wscan_t::storage_type storage[warps_no]; - wscan_t().inclusive_scan(value, value, storage[warp_id]); + using wscan_t = rocprim::warp_scan; + __shared__ typename wscan_t::storage_type storage[warps_no]; + wscan_t().inclusive_scan(value, value, storage[warp_id]); - device_output[index] = value; + device_output[index] = value; + } } template -__global__ -__launch_bounds__(BlockSize) +__global__ __launch_bounds__(BlockSize) void warp_inclusive_scan_initial_value_kernel(T* device_input, T* device_output, T initial_value) { - constexpr unsigned int warps_no = BlockSize / LogicalWarpSize; - const unsigned int warp_id = rocprim::detail::logical_warp_id(); - unsigned int index = threadIdx.x + (blockIdx.x * blockDim.x); + if constexpr(LogicalWarpSize <= rocprim::arch::wavefront::max_size()) + { + constexpr unsigned int warps_no = BlockSize / LogicalWarpSize; + const unsigned int warp_id = rocprim::detail::logical_warp_id(); + unsigned int index = threadIdx.x + (blockIdx.x * blockDim.x); - T value = device_input[index]; + T value = device_input[index]; - using wscan_t = rocprim::warp_scan; - __shared__ typename wscan_t::storage_type storage[warps_no]; - wscan_t().inclusive_scan(value, value, storage[warp_id], initial_value); + using wscan_t = rocprim::warp_scan; + __shared__ typename wscan_t::storage_type storage[warps_no]; + wscan_t().inclusive_scan(value, value, storage[warp_id], initial_value); - device_output[index] = value; + device_output[index] = value; + } } template -__global__ -__launch_bounds__(BlockSize) +__global__ __launch_bounds__(BlockSize) void warp_inclusive_scan_reduce_kernel(T* device_input, T* device_output, T* device_output_reductions) { - constexpr unsigned int warps_no = BlockSize / LogicalWarpSize; - const unsigned int warp_id = rocprim::detail::logical_warp_id(); - unsigned int index = threadIdx.x + ( blockIdx.x * BlockSize ); - - T value = device_input[index]; - T reduction; - - using wscan_t = rocprim::warp_scan; - __shared__ typename wscan_t::storage_type storage[warps_no]; - wscan_t().inclusive_scan(value, value, reduction, storage[warp_id]); - - device_output[index] = value; - if((threadIdx.x % LogicalWarpSize) == 0) + if constexpr(LogicalWarpSize <= rocprim::arch::wavefront::max_size()) { - device_output_reductions[index / LogicalWarpSize] = reduction; + constexpr unsigned int warps_no = BlockSize / LogicalWarpSize; + const unsigned int warp_id = rocprim::detail::logical_warp_id(); + unsigned int index = threadIdx.x + (blockIdx.x * BlockSize); + + T value = device_input[index]; + T reduction; + + using wscan_t = rocprim::warp_scan; + __shared__ typename wscan_t::storage_type storage[warps_no]; + wscan_t().inclusive_scan(value, value, reduction, storage[warp_id]); + + device_output[index] = value; + if((threadIdx.x % LogicalWarpSize) == 0) + { + device_output_reductions[index / LogicalWarpSize] = reduction; + } } } template -__global__ -__launch_bounds__(BlockSize) +__global__ __launch_bounds__(BlockSize) void warp_inclusive_scan_reduce_initial_value_kernel(T* device_input, T* device_output, T* device_output_reductions, T initial_value) { - constexpr unsigned int warps_no = BlockSize / LogicalWarpSize; - const unsigned int warp_id = rocprim::detail::logical_warp_id(); - unsigned int index = threadIdx.x + (blockIdx.x * BlockSize); - - T value = device_input[index]; - T reduction; - - using wscan_t = rocprim::warp_scan; - __shared__ typename wscan_t::storage_type storage[warps_no]; - wscan_t().inclusive_scan(value, value, reduction, storage[warp_id], initial_value); - - device_output[index] = value; - if((threadIdx.x % LogicalWarpSize) == 0) + if constexpr(LogicalWarpSize <= rocprim::arch::wavefront::max_size()) { - device_output_reductions[index / LogicalWarpSize] = reduction; + constexpr unsigned int warps_no = BlockSize / LogicalWarpSize; + const unsigned int warp_id = rocprim::detail::logical_warp_id(); + unsigned int index = threadIdx.x + (blockIdx.x * BlockSize); + + T value = device_input[index]; + T reduction; + + using wscan_t = rocprim::warp_scan; + __shared__ typename wscan_t::storage_type storage[warps_no]; + wscan_t().inclusive_scan(value, value, reduction, storage[warp_id], initial_value); + + device_output[index] = value; + if((threadIdx.x % LogicalWarpSize) == 0) + { + device_output_reductions[index / LogicalWarpSize] = reduction; + } } } template -__global__ -__launch_bounds__(BlockSize) +__global__ __launch_bounds__(BlockSize) void warp_exclusive_scan_kernel(T* device_input, T* device_output, T init) { - constexpr unsigned int warps_no = BlockSize / LogicalWarpSize; - const unsigned int warp_id = rocprim::detail::logical_warp_id(); - unsigned int index = threadIdx.x + (blockIdx.x * blockDim.x); + if constexpr(LogicalWarpSize <= rocprim::arch::wavefront::max_size()) + { + constexpr unsigned int warps_no = BlockSize / LogicalWarpSize; + const unsigned int warp_id = rocprim::detail::logical_warp_id(); + unsigned int index = threadIdx.x + (blockIdx.x * blockDim.x); - T value = device_input[index]; + T value = device_input[index]; - using wscan_t = rocprim::warp_scan; - __shared__ typename wscan_t::storage_type storage[warps_no]; - wscan_t().exclusive_scan(value, value, init, storage[warp_id]); + using wscan_t = rocprim::warp_scan; + __shared__ typename wscan_t::storage_type storage[warps_no]; + wscan_t().exclusive_scan(value, value, init, storage[warp_id]); - device_output[index] = value; + device_output[index] = value; + } } template -__global__ -__launch_bounds__(BlockSize) +__global__ __launch_bounds__(BlockSize) void warp_exclusive_scan_reduce_kernel(T* device_input, T* device_output, T* device_output_reductions, T init) { - constexpr unsigned int warps_no = BlockSize / LogicalWarpSize; - const unsigned int warp_id = rocprim::detail::logical_warp_id(); - unsigned int index = threadIdx.x + (blockIdx.x * blockDim.x); - - T value = device_input[index]; - T reduction; - - using wscan_t = rocprim::warp_scan; - __shared__ typename wscan_t::storage_type storage[warps_no]; - wscan_t().exclusive_scan(value, value, init, reduction, storage[warp_id]); - device_output[index] = value; - if((threadIdx.x % LogicalWarpSize) == 0) + if constexpr(LogicalWarpSize <= rocprim::arch::wavefront::max_size()) { - device_output_reductions[index / LogicalWarpSize] = reduction; + constexpr unsigned int warps_no = BlockSize / LogicalWarpSize; + const unsigned int warp_id = rocprim::detail::logical_warp_id(); + unsigned int index = threadIdx.x + (blockIdx.x * blockDim.x); + + T value = device_input[index]; + T reduction; + + using wscan_t = rocprim::warp_scan; + __shared__ typename wscan_t::storage_type storage[warps_no]; + wscan_t().exclusive_scan(value, value, init, reduction, storage[warp_id]); + + device_output[index] = value; + if((threadIdx.x % LogicalWarpSize) == 0) + { + device_output_reductions[index / LogicalWarpSize] = reduction; + } } } template -__global__ -__launch_bounds__(BlockSize) +__global__ __launch_bounds__(BlockSize) void warp_broadcast_kernel(T* device_input, T* device_output) { - const unsigned int index = threadIdx.x + (blockIdx.x * blockDim.x); - const unsigned int warp_id = index / LogicalWarpSize; - const unsigned int src_lane = warp_id % LogicalWarpSize; + if constexpr(LogicalWarpSize <= rocprim::arch::wavefront::max_size()) + { + const unsigned int index = threadIdx.x + (blockIdx.x * blockDim.x); + const unsigned int warp_id = index / LogicalWarpSize; + const unsigned int src_lane = warp_id % LogicalWarpSize; - T value = device_input[index]; + T value = device_input[index]; - using wscan_t = rocprim::warp_scan; - __shared__ typename wscan_t::storage_type storage; - value = wscan_t().broadcast(value, src_lane, storage); + using wscan_t = rocprim::warp_scan; + __shared__ typename wscan_t::storage_type storage; + value = wscan_t().broadcast(value, src_lane, storage); - device_output[index] = value; + device_output[index] = value; + } } template __global__ __launch_bounds__(BlockSize) void warp_exclusive_scan_wo_init_kernel(T* device_input, T* device_output) { - static constexpr unsigned int block_warps_no = BlockSize / LogicalWarpSize; + if constexpr(LogicalWarpSize <= rocprim::arch::wavefront::max_size()) + { + static constexpr unsigned int block_warps_no = BlockSize / LogicalWarpSize; - const unsigned int global_index = threadIdx.x + (blockIdx.x * blockDim.x); - const unsigned int block_warp_id = threadIdx.x / LogicalWarpSize; + const unsigned int global_index = threadIdx.x + (blockIdx.x * blockDim.x); + const unsigned int block_warp_id = threadIdx.x / LogicalWarpSize; - T value = device_input[global_index]; + T value = device_input[global_index]; - using wscan_t = rocprim::warp_scan; - __shared__ typename wscan_t::storage_type storage[block_warps_no]; - wscan_t().exclusive_scan(value, value, storage[block_warp_id]); + using wscan_t = rocprim::warp_scan; + __shared__ typename wscan_t::storage_type storage[block_warps_no]; + wscan_t().exclusive_scan(value, value, storage[block_warp_id]); - device_output[global_index] = value; + device_output[global_index] = value; + } } template @@ -196,75 +214,83 @@ void warp_exclusive_scan_reduce_wo_init_kernel(T* device_input, T* device_output, T* device_output_reductions) { - static constexpr unsigned int block_warps_no = BlockSize / LogicalWarpSize; + if constexpr(LogicalWarpSize <= rocprim::arch::wavefront::max_size()) + { + static constexpr unsigned int block_warps_no = BlockSize / LogicalWarpSize; - const unsigned int global_index = threadIdx.x + (blockIdx.x * blockDim.x); - const unsigned int block_warp_id = threadIdx.x / LogicalWarpSize; - const unsigned int lane_id = threadIdx.x % LogicalWarpSize; - const unsigned int global_warp_id = global_index / LogicalWarpSize; + const unsigned int global_index = threadIdx.x + (blockIdx.x * blockDim.x); + const unsigned int block_warp_id = threadIdx.x / LogicalWarpSize; + const unsigned int lane_id = threadIdx.x % LogicalWarpSize; + const unsigned int global_warp_id = global_index / LogicalWarpSize; - T value = device_input[global_index]; - T reduction; + T value = device_input[global_index]; + T reduction; - using wscan_t = rocprim::warp_scan; - __shared__ typename wscan_t::storage_type storage[block_warps_no]; - wscan_t().exclusive_scan(value, value, storage[block_warp_id], reduction); + using wscan_t = rocprim::warp_scan; + __shared__ typename wscan_t::storage_type storage[block_warps_no]; + wscan_t().exclusive_scan(value, value, storage[block_warp_id], reduction); - device_output[global_index] = value; - if(lane_id == 0) - { - device_output_reductions[global_warp_id] = reduction; + device_output[global_index] = value; + if(lane_id == 0) + { + device_output_reductions[global_warp_id] = reduction; + } } } template -__global__ -__launch_bounds__(BlockSize) +__global__ __launch_bounds__(BlockSize) void warp_scan_kernel(T* device_input, T* device_inclusive_output, T* device_exclusive_output, T init) { - constexpr unsigned int warps_no = BlockSize / LogicalWarpSize; - const unsigned int warp_id = rocprim::detail::logical_warp_id(); - unsigned int index = threadIdx.x + (blockIdx.x * blockDim.x); + if constexpr(LogicalWarpSize <= rocprim::arch::wavefront::max_size()) + { + constexpr unsigned int warps_no = BlockSize / LogicalWarpSize; + const unsigned int warp_id = rocprim::detail::logical_warp_id(); + unsigned int index = threadIdx.x + (blockIdx.x * blockDim.x); - T input = device_input[index]; - T inclusive_output, exclusive_output; + T input = device_input[index]; + T inclusive_output, exclusive_output; - using wscan_t = rocprim::warp_scan; - __shared__ typename wscan_t::storage_type storage[warps_no]; - wscan_t().scan(input, inclusive_output, exclusive_output, init, storage[warp_id]); + using wscan_t = rocprim::warp_scan; + __shared__ typename wscan_t::storage_type storage[warps_no]; + wscan_t().scan(input, inclusive_output, exclusive_output, init, storage[warp_id]); - device_inclusive_output[index] = inclusive_output; - device_exclusive_output[index] = exclusive_output; + device_inclusive_output[index] = inclusive_output; + device_exclusive_output[index] = exclusive_output; + } } template -__global__ -__launch_bounds__(BlockSize) +__global__ __launch_bounds__(BlockSize) void warp_scan_reduce_kernel(T* device_input, T* device_inclusive_output, T* device_exclusive_output, T* device_output_reductions, T init) { - constexpr unsigned int warps_no = BlockSize / LogicalWarpSize; - const unsigned int warp_id = rocprim::detail::logical_warp_id(); - unsigned int index = threadIdx.x + (blockIdx.x * blockDim.x); - - T input = device_input[index]; - T inclusive_output, exclusive_output, reduction; - - using wscan_t = rocprim::warp_scan; - __shared__ typename wscan_t::storage_type storage[warps_no]; - wscan_t().scan(input, inclusive_output, exclusive_output, init, reduction, storage[warp_id]); - - device_inclusive_output[index] = inclusive_output; - device_exclusive_output[index] = exclusive_output; - if((threadIdx.x % LogicalWarpSize) == 0) + if constexpr(LogicalWarpSize <= rocprim::arch::wavefront::max_size()) { - device_output_reductions[index / LogicalWarpSize] = reduction; + constexpr unsigned int warps_no = BlockSize / LogicalWarpSize; + const unsigned int warp_id = rocprim::detail::logical_warp_id(); + unsigned int index = threadIdx.x + (blockIdx.x * blockDim.x); + + T input = device_input[index]; + T inclusive_output, exclusive_output, reduction; + + using wscan_t = rocprim::warp_scan; + __shared__ typename wscan_t::storage_type storage[warps_no]; + wscan_t() + .scan(input, inclusive_output, exclusive_output, init, reduction, storage[warp_id]); + + device_inclusive_output[index] = inclusive_output; + device_exclusive_output[index] = exclusive_output; + if((threadIdx.x % LogicalWarpSize) == 0) + { + device_output_reductions[index / LogicalWarpSize] = reduction; + } } } diff --git a/test/rocprim/test_warp_sort.kernels.hpp b/test/rocprim/test_warp_sort.kernels.hpp index 2fe3f4763..48d3f3f83 100644 --- a/test/rocprim/test_warp_sort.kernels.hpp +++ b/test/rocprim/test_warp_sort.kernels.hpp @@ -60,20 +60,23 @@ template __device__ auto test_hip_warp_sort_impl(KeyType* device_key_output) - -> std::enable_if_t<(LogicalWarpSize <= ::rocprim::arch::wavefront::min_size())> + -> std::enable_if_t<(LogicalWarpSize <= ::rocprim::arch::wavefront::max_size())> { - const unsigned int lid = threadIdx.x; - const unsigned int block_offset = blockIdx.x * ItemsPerThread * BlockSize; + if(LogicalWarpSize <= ::rocprim::arch::wavefront::size()) + { + const unsigned int lid = threadIdx.x; + const unsigned int block_offset = blockIdx.x * ItemsPerThread * BlockSize; - KeyType keys[ItemsPerThread]; - ::rocprim::block_load_direct_warp_striped(lid, - device_key_output + block_offset, - keys); + KeyType keys[ItemsPerThread]; + ::rocprim::block_load_direct_warp_striped(lid, + device_key_output + block_offset, + keys); - rocprim::warp_sort wsort; - wsort.sort(keys); + rocprim::warp_sort wsort; + wsort.sort(keys); - ::rocprim::block_store_direct_blocked(lid, device_key_output + block_offset, keys); + ::rocprim::block_store_direct_blocked(lid, device_key_output + block_offset, keys); + } } template __device__ auto test_hip_warp_sort_impl(KeyType*) - -> std::enable_if_t<(LogicalWarpSize > ::rocprim::arch::wavefront::min_size())> + -> std::enable_if_t<(LogicalWarpSize > ::rocprim::arch::wavefront::max_size())> {} template __device__ auto test_hip_sort_key_value_impl(KeyType* device_key_output, ValueType* device_value_output) - -> std::enable_if_t<(LogicalWarpSize <= ::rocprim::arch::wavefront::min_size())> + -> std::enable_if_t<(LogicalWarpSize <= ::rocprim::arch::wavefront::max_size())> { - const unsigned int lid = threadIdx.x; - const unsigned int block_offset = blockIdx.x * ItemsPerThread * BlockSize; + if(LogicalWarpSize <= ::rocprim::arch::wavefront::size()) + { + const unsigned int lid = threadIdx.x; + const unsigned int block_offset = blockIdx.x * ItemsPerThread * BlockSize; - KeyType keys[ItemsPerThread]; - ValueType values[ItemsPerThread]; - ::rocprim::block_load_direct_warp_striped(lid, - device_key_output + block_offset, - keys); - ::rocprim::block_load_direct_warp_striped(lid, - device_value_output + block_offset, - values); + KeyType keys[ItemsPerThread]; + ValueType values[ItemsPerThread]; + ::rocprim::block_load_direct_warp_striped(lid, + device_key_output + block_offset, + keys); + ::rocprim::block_load_direct_warp_striped(lid, + device_value_output + + block_offset, + values); - rocprim::warp_sort wsort; - wsort.sort(keys, values); + rocprim::warp_sort wsort; + wsort.sort(keys, values); - ::rocprim::block_store_direct_blocked(lid, device_key_output + block_offset, keys); - ::rocprim::block_store_direct_blocked(lid, device_value_output + block_offset, values); + ::rocprim::block_store_direct_blocked(lid, device_key_output + block_offset, keys); + ::rocprim::block_store_direct_blocked(lid, device_value_output + block_offset, values); + } } template __device__ auto test_hip_sort_key_value_impl(KeyType*, ValueType*) - -> std::enable_if_t<(LogicalWarpSize > ::rocprim::arch::wavefront::min_size())> + -> std::enable_if_t<(LogicalWarpSize > ::rocprim::arch::wavefront::max_size())> {} template #include -template< - class T, - unsigned int ItemsPerThread, - unsigned int WarpSize, - ::rocprim::warp_store_method Method -> +template struct Params { using type = T; static constexpr unsigned int items_per_thread = ItemsPerThread; - static constexpr unsigned int warp_size = WarpSize; + static constexpr unsigned int warp_size = VirtualWaveSize; static constexpr ::rocprim::warp_store_method method = Method; }; diff --git a/toolchain-windows.cmake b/toolchain-windows.cmake index fe1cae6b5..974fb7b33 100644 --- a/toolchain-windows.cmake +++ b/toolchain-windows.cmake @@ -25,7 +25,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DWIN32 -D_CRT_SECURE_NO_WARNINGS") # flags for clang direct use # -Wno-ignored-attributes to avoid warning: __declspec attribute 'dllexport' is not supported [-Wignored-attributes] which is used by msvc compiler -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14 -fms-extensions -fms-compatibility -Wno-ignored-attributes") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17 -fms-extensions -fms-compatibility -Wno-ignored-attributes") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__HIP_PLATFORM_AMD__ -D__HIP_ROCclr__")