diff --git a/cub/agent/agent_reduce_by_key.cuh b/cub/agent/agent_reduce_by_key.cuh index 48aa17a87b..abc146579d 100644 --- a/cub/agent/agent_reduce_by_key.cuh +++ b/cub/agent/agent_reduce_by_key.cuh @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011, Duane Merrill. All rights reserved. - * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: @@ -13,10 +13,10 @@ * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND @@ -27,516 +27,682 @@ ******************************************************************************/ /** - * \file - * cub::AgentReduceByKey implements a stateful abstraction of CUDA thread blocks for participating in device-wide reduce-value-by-key. + * @file cub::AgentReduceByKey implements a stateful abstraction of CUDA thread + * blocks for participating in device-wide reduce-value-by-key. 
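For orientation, a minimal host-side sketch (not part of this patch) of the device-level entry point that this agent implements; the use of cub::Sum and the buffer names are illustrative assumptions:

#include <cub/cub.cuh>

// Sums the values belonging to each run of consecutive equal keys.
void reduce_by_key_example(const int *d_keys_in, const float *d_values_in,
                           int *d_unique_out, float *d_aggregates_out,
                           int *d_num_runs_out, int num_items)
{
  void *d_temp_storage = nullptr;
  size_t temp_storage_bytes = 0;

  // First call: query the required temporary storage size
  cub::DeviceReduce::ReduceByKey(d_temp_storage, temp_storage_bytes,
                                 d_keys_in, d_unique_out,
                                 d_values_in, d_aggregates_out,
                                 d_num_runs_out, cub::Sum(), num_items);
  cudaMalloc(&d_temp_storage, temp_storage_bytes);

  // Second call: perform the device-wide reduce-by-key
  cub::DeviceReduce::ReduceByKey(d_temp_storage, temp_storage_bytes,
                                 d_keys_in, d_unique_out,
                                 d_values_in, d_aggregates_out,
                                 d_num_runs_out, cub::Sum(), num_items);
  cudaFree(d_temp_storage);
}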
*/ #pragma once #include -#include "single_pass_scan_operators.cuh" -#include "../block/block_load.cuh" -#include "../block/block_store.cuh" -#include "../block/block_scan.cuh" -#include "../block/block_discontinuity.cuh" -#include "../config.cuh" -#include "../iterator/cache_modified_input_iterator.cuh" -#include "../iterator/constant_input_iterator.cuh" +#include +#include +#include +#include +#include +#include +#include +#include CUB_NAMESPACE_BEGIN - /****************************************************************************** * Tuning policy types ******************************************************************************/ /** - * Parameterizable tuning policy type for AgentReduceByKey + * @brief Parameterizable tuning policy type for AgentReduceByKey + * + * @tparam _BLOCK_THREADS + * Threads per thread block + * + * @tparam _ITEMS_PER_THREAD + * Items per thread (per tile of input) + * + * @tparam _LOAD_ALGORITHM + * The BlockLoad algorithm to use + * + * @tparam _LOAD_MODIFIER + * Cache load modifier for reading input elements + * + * @tparam _SCAN_ALGORITHM + * The BlockScan algorithm to use */ -template < - int _BLOCK_THREADS, ///< Threads per thread block - int _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) - BlockLoadAlgorithm _LOAD_ALGORITHM, ///< The BlockLoad algorithm to use - CacheLoadModifier _LOAD_MODIFIER, ///< Cache load modifier for reading input elements - BlockScanAlgorithm _SCAN_ALGORITHM> ///< The BlockScan algorithm to use +template struct AgentReduceByKeyPolicy { - enum - { - BLOCK_THREADS = _BLOCK_THREADS, ///< Threads per thread block - ITEMS_PER_THREAD = _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) - }; + ///< Threads per thread block + static constexpr int BLOCK_THREADS = _BLOCK_THREADS; - static const BlockLoadAlgorithm LOAD_ALGORITHM = _LOAD_ALGORITHM; ///< The BlockLoad algorithm to use - static const CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER; ///< Cache load modifier for reading input elements - static const BlockScanAlgorithm SCAN_ALGORITHM = _SCAN_ALGORITHM; ///< The BlockScan algorithm to use -}; + ///< Items per thread (per tile of input) + static constexpr int ITEMS_PER_THREAD = _ITEMS_PER_THREAD; + ///< The BlockLoad algorithm to use + static constexpr BlockLoadAlgorithm LOAD_ALGORITHM = _LOAD_ALGORITHM; + + ///< Cache load modifier for reading input elements + static constexpr CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER; + + ///< The BlockScan algorithm to use + static constexpr const BlockScanAlgorithm SCAN_ALGORITHM = _SCAN_ALGORITHM; +}; /****************************************************************************** * Thread block abstractions ******************************************************************************/ /** - * \brief AgentReduceByKey implements a stateful abstraction of CUDA thread blocks for participating in device-wide reduce-value-by-key + * @brief AgentReduceByKey implements a stateful abstraction of CUDA thread + * blocks for participating in device-wide reduce-value-by-key + * + * @tparam AgentReduceByKeyPolicyT + * Parameterized AgentReduceByKeyPolicy tuning policy type + * + * @tparam KeysInputIteratorT + * Random-access input iterator type for keys + * + * @tparam UniqueOutputIteratorT + * Random-access output iterator type for keys + * + * @tparam ValuesInputIteratorT + * Random-access input iterator type for values + * + * @tparam AggregatesOutputIteratorT + * Random-access output iterator type for values + * + * @tparam NumRunsOutputIteratorT + * Output iterator 
type for recording number of items selected + * + * @tparam EqualityOpT + * KeyT equality operator type + * + * @tparam ReductionOpT + * ValueT reduction operator type + * + * @tparam OffsetT + * Signed integer type for global offsets */ -template < - typename AgentReduceByKeyPolicyT, ///< Parameterized AgentReduceByKeyPolicy tuning policy type - typename KeysInputIteratorT, ///< Random-access input iterator type for keys - typename UniqueOutputIteratorT, ///< Random-access output iterator type for keys - typename ValuesInputIteratorT, ///< Random-access input iterator type for values - typename AggregatesOutputIteratorT, ///< Random-access output iterator type for values - typename NumRunsOutputIteratorT, ///< Output iterator type for recording number of items selected - typename EqualityOpT, ///< KeyT equality operator type - typename ReductionOpT, ///< ValueT reduction operator type - typename OffsetT> ///< Signed integer type for global offsets +template struct AgentReduceByKey { - //--------------------------------------------------------------------- - // Types and constants - //--------------------------------------------------------------------- + //--------------------------------------------------------------------- + // Types and constants + //--------------------------------------------------------------------- - // The input keys type - using KeyInputT = cub::detail::value_t; + // The input keys type + using KeyInputT = cub::detail::value_t; - // The output keys type - using KeyOutputT = - cub::detail::non_void_value_t; + // The output keys type + using KeyOutputT = + cub::detail::non_void_value_t; - // The input values type - using ValueInputT = cub::detail::value_t; + // The input values type + using ValueInputT = cub::detail::value_t; - // The output values type - using ValueOutputT = - cub::detail::non_void_value_t; + // Tuple type for scanning (pairs accumulated segment-value with + // segment-index) + using OffsetValuePairT = KeyValuePair; - // Tuple type for scanning (pairs accumulated segment-value with segment-index) - using OffsetValuePairT = KeyValuePair; - - // Tuple type for pairing keys and values - using KeyValuePairT = KeyValuePair; - - // Tile status descriptor interface type - using ScanTileStateT = ReduceByKeyScanTileState; - - // Guarded inequality functor - template - struct GuardedInequalityWrapper - { - _EqualityOpT op; ///< Wrapped equality operator - int num_remaining; ///< Items remaining + // Tuple type for pairing keys and values + using KeyValuePairT = KeyValuePair; - /// Constructor - __host__ __device__ __forceinline__ - GuardedInequalityWrapper(_EqualityOpT op, int num_remaining) : op(op), num_remaining(num_remaining) {} + // Tile status descriptor interface type + using ScanTileStateT = ReduceByKeyScanTileState; - /// Boolean inequality operator, returns (a != b) - template - __host__ __device__ __forceinline__ bool operator()(const T &a, const T &b, int idx) const - { - if (idx < num_remaining) - return !op(a, b); // In bounds + // Guarded inequality functor + template + struct GuardedInequalityWrapper + { + /// Wrapped equality operator + _EqualityOpT op; - // Return true if first out-of-bounds item, false otherwise - return (idx == num_remaining); - } - }; + /// Items remaining + int num_remaining; + /// Constructor + __host__ __device__ __forceinline__ + GuardedInequalityWrapper(_EqualityOpT op, int num_remaining) + : op(op) + , num_remaining(num_remaining) + {} - // Constants - enum + /// Boolean inequality operator, returns (a != b) + 
template + __host__ __device__ __forceinline__ bool operator()(const T &a, + const T &b, + int idx) const { - BLOCK_THREADS = AgentReduceByKeyPolicyT::BLOCK_THREADS, - ITEMS_PER_THREAD = AgentReduceByKeyPolicyT::ITEMS_PER_THREAD, - TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD, - TWO_PHASE_SCATTER = (ITEMS_PER_THREAD > 1), - - // Whether or not the scan operation has a zero-valued identity value (true if we're performing addition on a primitive type) - HAS_IDENTITY_ZERO = (std::is_same::value) && - (Traits::PRIMITIVE), - }; - - // Cache-modified Input iterator wrapper type (for applying cache modifier) for keys - // Wrap the native input pointer with CacheModifiedValuesInputIterator - // or directly use the supplied input iterator type - using WrappedKeysInputIteratorT = cub::detail::conditional_t< - std::is_pointer::value, - CacheModifiedInputIterator, - KeysInputIteratorT>; - - // Cache-modified Input iterator wrapper type (for applying cache modifier) for values - // Wrap the native input pointer with CacheModifiedValuesInputIterator - // or directly use the supplied input iterator type - using WrappedValuesInputIteratorT = cub::detail::conditional_t< - std::is_pointer::value, - CacheModifiedInputIterator, - ValuesInputIteratorT>; - - // Cache-modified Input iterator wrapper type (for applying cache modifier) for fixup values - // Wrap the native input pointer with CacheModifiedValuesInputIterator - // or directly use the supplied input iterator type - using WrappedFixupInputIteratorT = cub::detail::conditional_t< - std::is_pointer::value, - CacheModifiedInputIterator, - AggregatesOutputIteratorT>; - - // Reduce-value-by-segment scan operator - using ReduceBySegmentOpT = ReduceBySegmentOp; - - // Parameterized BlockLoad type for keys - using BlockLoadKeysT = BlockLoad 1); + + // Whether or not the scan operation has a zero-valued identity value (true + // if we're performing addition on a primitive type) + static constexpr int HAS_IDENTITY_ZERO = + (std::is_same::value) && + (Traits::PRIMITIVE); + + // Cache-modified Input iterator wrapper type (for applying cache modifier) + // for keys Wrap the native input pointer with + // CacheModifiedValuesInputIterator or directly use the supplied input + // iterator type + using WrappedKeysInputIteratorT = cub::detail::conditional_t< + std::is_pointer::value, + CacheModifiedInputIterator, + KeysInputIteratorT>; + + // Cache-modified Input iterator wrapper type (for applying cache modifier) + // for values Wrap the native input pointer with + // CacheModifiedValuesInputIterator or directly use the supplied input + // iterator type + using WrappedValuesInputIteratorT = cub::detail::conditional_t< + std::is_pointer::value, + CacheModifiedInputIterator, + ValuesInputIteratorT>; + + // Cache-modified Input iterator wrapper type (for applying cache modifier) + // for fixup values Wrap the native input pointer with + // CacheModifiedValuesInputIterator or directly use the supplied input + // iterator type + using WrappedFixupInputIteratorT = cub::detail::conditional_t< + std::is_pointer::value, + CacheModifiedInputIterator, + AggregatesOutputIteratorT>; + + // Reduce-value-by-segment scan operator + using ReduceBySegmentOpT = ReduceBySegmentOp; + + // Parameterized BlockLoad type for keys + using BlockLoadKeysT = BlockLoad; + + // Parameterized BlockLoad type for values + using BlockLoadValuesT = BlockLoad; - // Parameterized BlockLoad type for values - using BlockLoadValuesT = BlockLoad; - - // Parameterized BlockDiscontinuity type for keys - using 
BlockDiscontinuityKeys = - BlockDiscontinuity; + // Parameterized BlockDiscontinuity type for keys + using BlockDiscontinuityKeys = BlockDiscontinuity; - // Parameterized BlockScan type - using BlockScanT = BlockScan; + // Parameterized BlockScan type + using BlockScanT = BlockScan; - // Callback type for obtaining tile prefix during block scan - using TilePrefixCallbackOpT = - TilePrefixCallbackOp; + // Callback type for obtaining tile prefix during block scan + using TilePrefixCallbackOpT = + TilePrefixCallbackOp; - // Key and value exchange types - typedef KeyOutputT KeyExchangeT[TILE_ITEMS + 1]; - typedef ValueOutputT ValueExchangeT[TILE_ITEMS + 1]; + // Key and value exchange types + typedef KeyOutputT KeyExchangeT[TILE_ITEMS + 1]; + typedef AccumT ValueExchangeT[TILE_ITEMS + 1]; - // Shared memory type for this thread block - union _TempStorage + // Shared memory type for this thread block + union _TempStorage + { + struct ScanStorage { - struct ScanStorage - { - typename BlockScanT::TempStorage scan; // Smem needed for tile scanning - typename TilePrefixCallbackOpT::TempStorage prefix; // Smem needed for cooperative prefix callback - typename BlockDiscontinuityKeys::TempStorage discontinuity; // Smem needed for discontinuity detection - } scan_storage; - - // Smem needed for loading keys - typename BlockLoadKeysT::TempStorage load_keys; - - // Smem needed for loading values - typename BlockLoadValuesT::TempStorage load_values; - - // Smem needed for compacting key value pairs(allows non POD items in this union) - Uninitialized raw_exchange; - }; - - // Alias wrapper allowing storage to be unioned - struct TempStorage : Uninitialized<_TempStorage> {}; - - - //--------------------------------------------------------------------- - // Per-thread fields - //--------------------------------------------------------------------- - - _TempStorage& temp_storage; ///< Reference to temp_storage - WrappedKeysInputIteratorT d_keys_in; ///< Input keys - UniqueOutputIteratorT d_unique_out; ///< Unique output keys - WrappedValuesInputIteratorT d_values_in; ///< Input values - AggregatesOutputIteratorT d_aggregates_out; ///< Output value aggregates - NumRunsOutputIteratorT d_num_runs_out; ///< Output pointer for total number of segments identified - EqualityOpT equality_op; ///< KeyT equality operator - ReductionOpT reduction_op; ///< Reduction operator - ReduceBySegmentOpT scan_op; ///< Reduce-by-segment scan operator - - - //--------------------------------------------------------------------- - // Constructor - //--------------------------------------------------------------------- - - // Constructor - __device__ __forceinline__ - AgentReduceByKey( - TempStorage& temp_storage, ///< Reference to temp_storage - KeysInputIteratorT d_keys_in, ///< Input keys - UniqueOutputIteratorT d_unique_out, ///< Unique output keys - ValuesInputIteratorT d_values_in, ///< Input values - AggregatesOutputIteratorT d_aggregates_out, ///< Output value aggregates - NumRunsOutputIteratorT d_num_runs_out, ///< Output pointer for total number of segments identified - EqualityOpT equality_op, ///< KeyT equality operator - ReductionOpT reduction_op) ///< ValueT reduction operator - : - temp_storage(temp_storage.Alias()), - d_keys_in(d_keys_in), - d_unique_out(d_unique_out), - d_values_in(d_values_in), - d_aggregates_out(d_aggregates_out), - d_num_runs_out(d_num_runs_out), - equality_op(equality_op), - reduction_op(reduction_op), - scan_op(reduction_op) - {} + // Smem needed for tile scanning + typename 
BlockScanT::TempStorage scan; + + // Smem needed for cooperative prefix callback + typename TilePrefixCallbackOpT::TempStorage prefix; + + // Smem needed for discontinuity detection + typename BlockDiscontinuityKeys::TempStorage discontinuity; + } scan_storage; + + // Smem needed for loading keys + typename BlockLoadKeysT::TempStorage load_keys; + + // Smem needed for loading values + typename BlockLoadValuesT::TempStorage load_values; + + // Smem needed for compacting key value pairs(allows non POD items in this + // union) + Uninitialized raw_exchange; + }; + + // Alias wrapper allowing storage to be unioned + struct TempStorage : Uninitialized<_TempStorage> + {}; + + //--------------------------------------------------------------------- + // Per-thread fields + //--------------------------------------------------------------------- + + /// Reference to temp_storage + _TempStorage &temp_storage; + + /// Input keys + WrappedKeysInputIteratorT d_keys_in; + + /// Unique output keys + UniqueOutputIteratorT d_unique_out; + + /// Input values + WrappedValuesInputIteratorT d_values_in; + + /// Output value aggregates + AggregatesOutputIteratorT d_aggregates_out; + + /// Output pointer for total number of segments identified + NumRunsOutputIteratorT d_num_runs_out; + + /// KeyT equality operator + EqualityOpT equality_op; + + /// Reduction operator + ReductionOpT reduction_op; + + /// Reduce-by-segment scan operator + ReduceBySegmentOpT scan_op; + + //--------------------------------------------------------------------- + // Constructor + //--------------------------------------------------------------------- + + /** + * @param temp_storage + * Reference to temp_storage + * + * @param d_keys_in + * Input keys + * + * @param d_unique_out + * Unique output keys + * + * @param d_values_in + * Input values + * + * @param d_aggregates_out + * Output value aggregates + * + * @param d_num_runs_out + * Output pointer for total number of segments identified + * + * @param equality_op + * KeyT equality operator + * + * @param reduction_op + * ValueT reduction operator + */ + __device__ __forceinline__ + AgentReduceByKey(TempStorage &temp_storage, + KeysInputIteratorT d_keys_in, + UniqueOutputIteratorT d_unique_out, + ValuesInputIteratorT d_values_in, + AggregatesOutputIteratorT d_aggregates_out, + NumRunsOutputIteratorT d_num_runs_out, + EqualityOpT equality_op, + ReductionOpT reduction_op) + : temp_storage(temp_storage.Alias()) + , d_keys_in(d_keys_in) + , d_unique_out(d_unique_out) + , d_values_in(d_values_in) + , d_aggregates_out(d_aggregates_out) + , d_num_runs_out(d_num_runs_out) + , equality_op(equality_op) + , reduction_op(reduction_op) + , scan_op(reduction_op) + {} + + //--------------------------------------------------------------------- + // Scatter utility methods + //--------------------------------------------------------------------- + + /** + * Directly scatter flagged items to output offsets + */ + __device__ __forceinline__ void + ScatterDirect(KeyValuePairT (&scatter_items)[ITEMS_PER_THREAD], + OffsetT (&segment_flags)[ITEMS_PER_THREAD], + OffsetT (&segment_indices)[ITEMS_PER_THREAD]) + { +// Scatter flagged keys and values +#pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) + { + if (segment_flags[ITEM]) + { + d_unique_out[segment_indices[ITEM]] = scatter_items[ITEM].key; + d_aggregates_out[segment_indices[ITEM]] = scatter_items[ITEM].value; + } + } + } + + /** + * 2-phase scatter flagged items to output offsets + * + * The exclusive scan causes each head flag to be 
paired with the previous + * value aggregate: the scatter offsets must be decremented for value + * aggregates + */ + __device__ __forceinline__ void + ScatterTwoPhase(KeyValuePairT (&scatter_items)[ITEMS_PER_THREAD], + OffsetT (&segment_flags)[ITEMS_PER_THREAD], + OffsetT (&segment_indices)[ITEMS_PER_THREAD], + OffsetT num_tile_segments, + OffsetT num_tile_segments_prefix) + { + CTA_SYNC(); + +// Compact and scatter pairs +#pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) + { + if (segment_flags[ITEM]) + { + temp_storage.raw_exchange + .Alias()[segment_indices[ITEM] - num_tile_segments_prefix] = + scatter_items[ITEM]; + } + } + CTA_SYNC(); - //--------------------------------------------------------------------- - // Scatter utility methods - //--------------------------------------------------------------------- + for (int item = threadIdx.x; item < num_tile_segments; + item += BLOCK_THREADS) + { + KeyValuePairT pair = temp_storage.raw_exchange.Alias()[item]; + d_unique_out[num_tile_segments_prefix + item] = pair.key; + d_aggregates_out[num_tile_segments_prefix + item] = pair.value; + } + } + + /** + * Scatter flagged items + */ + __device__ __forceinline__ void + Scatter(KeyValuePairT (&scatter_items)[ITEMS_PER_THREAD], + OffsetT (&segment_flags)[ITEMS_PER_THREAD], + OffsetT (&segment_indices)[ITEMS_PER_THREAD], + OffsetT num_tile_segments, + OffsetT num_tile_segments_prefix) + { + // Do a one-phase scatter if (a) two-phase is disabled or (b) the average + // number of selected items per thread is less than one + if (TWO_PHASE_SCATTER && (num_tile_segments > BLOCK_THREADS)) + { + ScatterTwoPhase(scatter_items, + segment_flags, + segment_indices, + num_tile_segments, + num_tile_segments_prefix); + } + else + { + ScatterDirect(scatter_items, segment_flags, segment_indices); + } + } + + //--------------------------------------------------------------------- + // Cooperatively scan a device-wide sequence of tiles with other CTAs + //--------------------------------------------------------------------- + + /** + * @brief Process a tile of input (dynamic chained scan) + * + * @tparam IS_LAST_TILE + * Whether the current tile is the last tile + * + * @param num_remaining + * Number of global input items remaining (including this tile) + * + * @param tile_idx + * Tile index + * + * @param tile_offset + * Tile offset + * + * @param tile_state + * Global tile state descriptor + */ + template + __device__ __forceinline__ void ConsumeTile(OffsetT num_remaining, + int tile_idx, + OffsetT tile_offset, + ScanTileStateT &tile_state) + { + // Tile keys + KeyOutputT keys[ITEMS_PER_THREAD]; + + // Tile keys shuffled up + KeyOutputT prev_keys[ITEMS_PER_THREAD]; + + // Tile values + AccumT values[ITEMS_PER_THREAD]; + + // Segment head flags + OffsetT head_flags[ITEMS_PER_THREAD]; + + // Segment indices + OffsetT segment_indices[ITEMS_PER_THREAD]; + + // Zipped values and segment flags|indices + OffsetValuePairT scan_items[ITEMS_PER_THREAD]; + + // Zipped key value pairs for scattering + KeyValuePairT scatter_items[ITEMS_PER_THREAD]; + + // Load keys + if (IS_LAST_TILE) + { + BlockLoadKeysT(temp_storage.load_keys) + .Load(d_keys_in + tile_offset, keys, num_remaining); + } + else + { + BlockLoadKeysT(temp_storage.load_keys).Load(d_keys_in + tile_offset, keys); + } - /** - * Directly scatter flagged items to output offsets - */ - __device__ __forceinline__ void ScatterDirect( - KeyValuePairT (&scatter_items)[ITEMS_PER_THREAD], - OffsetT (&segment_flags)[ITEMS_PER_THREAD], - OffsetT 
(&segment_indices)[ITEMS_PER_THREAD]) + // Load tile predecessor key in first thread + KeyOutputT tile_predecessor; + if (threadIdx.x == 0) { - // Scatter flagged keys and values - #pragma unroll - for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) - { - if (segment_flags[ITEM]) - { - d_unique_out[segment_indices[ITEM]] = scatter_items[ITEM].key; - d_aggregates_out[segment_indices[ITEM]] = scatter_items[ITEM].value; - } - } + // if (tile_idx == 0) + // first tile gets repeat of first item (thus first item will not + // be flagged as a head) + // else + // Subsequent tiles get last key from previous tile + tile_predecessor = (tile_idx == 0) ? keys[0] : d_keys_in[tile_offset - 1]; } + CTA_SYNC(); - /** - * 2-phase scatter flagged items to output offsets - * - * The exclusive scan causes each head flag to be paired with the previous - * value aggregate: the scatter offsets must be decremented for value aggregates - */ - __device__ __forceinline__ void ScatterTwoPhase( - KeyValuePairT (&scatter_items)[ITEMS_PER_THREAD], - OffsetT (&segment_flags)[ITEMS_PER_THREAD], - OffsetT (&segment_indices)[ITEMS_PER_THREAD], - OffsetT num_tile_segments, - OffsetT num_tile_segments_prefix) + // Load values + if (IS_LAST_TILE) { - CTA_SYNC(); - - // Compact and scatter pairs - #pragma unroll - for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) - { - if (segment_flags[ITEM]) - { - temp_storage.raw_exchange.Alias()[segment_indices[ITEM] - num_tile_segments_prefix] = scatter_items[ITEM]; - } - } - - CTA_SYNC(); - - for (int item = threadIdx.x; item < num_tile_segments; item += BLOCK_THREADS) - { - KeyValuePairT pair = temp_storage.raw_exchange.Alias()[item]; - d_unique_out[num_tile_segments_prefix + item] = pair.key; - d_aggregates_out[num_tile_segments_prefix + item] = pair.value; - } + BlockLoadValuesT(temp_storage.load_values) + .Load(d_values_in + tile_offset, values, num_remaining); } + else + { + BlockLoadValuesT(temp_storage.load_values) + .Load(d_values_in + tile_offset, values); + } + + CTA_SYNC(); + // Initialize head-flags and shuffle up the previous keys + if (IS_LAST_TILE) + { + // Use custom flag operator to additionally flag the first out-of-bounds + // item + GuardedInequalityWrapper flag_op(equality_op, num_remaining); + BlockDiscontinuityKeys(temp_storage.scan_storage.discontinuity) + .FlagHeads(head_flags, keys, prev_keys, flag_op, tile_predecessor); + } + else + { + InequalityWrapper flag_op(equality_op); + BlockDiscontinuityKeys(temp_storage.scan_storage.discontinuity) + .FlagHeads(head_flags, keys, prev_keys, flag_op, tile_predecessor); + } - /** - * Scatter flagged items - */ - __device__ __forceinline__ void Scatter( - KeyValuePairT (&scatter_items)[ITEMS_PER_THREAD], - OffsetT (&segment_flags)[ITEMS_PER_THREAD], - OffsetT (&segment_indices)[ITEMS_PER_THREAD], - OffsetT num_tile_segments, - OffsetT num_tile_segments_prefix) +// Zip values and head flags +#pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) { - // Do a one-phase scatter if (a) two-phase is disabled or (b) the average number of selected items per thread is less than one - if (TWO_PHASE_SCATTER && (num_tile_segments > BLOCK_THREADS)) - { - ScatterTwoPhase( - scatter_items, - segment_flags, - segment_indices, - num_tile_segments, - num_tile_segments_prefix); - } - else - { - ScatterDirect( - scatter_items, - segment_flags, - segment_indices); - } + scan_items[ITEM].value = values[ITEM]; + scan_items[ITEM].key = head_flags[ITEM]; } + // Perform exclusive tile scan + // Inclusive block-wide scan aggregate + 
OffsetValuePairT block_aggregate; - //--------------------------------------------------------------------- - // Cooperatively scan a device-wide sequence of tiles with other CTAs - //--------------------------------------------------------------------- + // Number of segments prior to this tile + OffsetT num_segments_prefix; - /** - * Process a tile of input (dynamic chained scan) - */ - template ///< Whether the current tile is the last tile - __device__ __forceinline__ void ConsumeTile( - OffsetT num_remaining, ///< Number of global input items remaining (including this tile) - int tile_idx, ///< Tile index - OffsetT tile_offset, ///< Tile offset - ScanTileStateT& tile_state) ///< Global tile state descriptor + // The tile prefix folded with block_aggregate + OffsetValuePairT total_aggregate; + + if (tile_idx == 0) { - KeyOutputT keys[ITEMS_PER_THREAD]; // Tile keys - KeyOutputT prev_keys[ITEMS_PER_THREAD]; // Tile keys shuffled up - ValueOutputT values[ITEMS_PER_THREAD]; // Tile values - OffsetT head_flags[ITEMS_PER_THREAD]; // Segment head flags - OffsetT segment_indices[ITEMS_PER_THREAD]; // Segment indices - OffsetValuePairT scan_items[ITEMS_PER_THREAD]; // Zipped values and segment flags|indices - KeyValuePairT scatter_items[ITEMS_PER_THREAD]; // Zipped key value pairs for scattering - - // Load keys - if (IS_LAST_TILE) - BlockLoadKeysT(temp_storage.load_keys).Load(d_keys_in + tile_offset, keys, num_remaining); - else - BlockLoadKeysT(temp_storage.load_keys).Load(d_keys_in + tile_offset, keys); - - // Load tile predecessor key in first thread - KeyOutputT tile_predecessor; - if (threadIdx.x == 0) - { - tile_predecessor = (tile_idx == 0) ? - keys[0] : // First tile gets repeat of first item (thus first item will not be flagged as a head) - d_keys_in[tile_offset - 1]; // Subsequent tiles get last key from previous tile - } - - CTA_SYNC(); - - // Load values - if (IS_LAST_TILE) - BlockLoadValuesT(temp_storage.load_values).Load(d_values_in + tile_offset, values, num_remaining); - else - BlockLoadValuesT(temp_storage.load_values).Load(d_values_in + tile_offset, values); - - CTA_SYNC(); - - // Initialize head-flags and shuffle up the previous keys - if (IS_LAST_TILE) - { - // Use custom flag operator to additionally flag the first out-of-bounds item - GuardedInequalityWrapper flag_op(equality_op, num_remaining); - BlockDiscontinuityKeys(temp_storage.scan_storage.discontinuity).FlagHeads( - head_flags, keys, prev_keys, flag_op, tile_predecessor); - } - else - { - InequalityWrapper flag_op(equality_op); - BlockDiscontinuityKeys(temp_storage.scan_storage.discontinuity).FlagHeads( - head_flags, keys, prev_keys, flag_op, tile_predecessor); - } - - // Zip values and head flags - #pragma unroll - for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) - { - scan_items[ITEM].value = values[ITEM]; - scan_items[ITEM].key = head_flags[ITEM]; - } - - // Perform exclusive tile scan - OffsetValuePairT block_aggregate; // Inclusive block-wide scan aggregate - OffsetT num_segments_prefix; // Number of segments prior to this tile - OffsetValuePairT total_aggregate; // The tile prefix folded with block_aggregate - if (tile_idx == 0) - { - // Scan first tile - BlockScanT(temp_storage.scan_storage.scan).ExclusiveScan(scan_items, scan_items, scan_op, block_aggregate); - num_segments_prefix = 0; - total_aggregate = block_aggregate; - - // Update tile status if there are successor tiles - if ((!IS_LAST_TILE) && (threadIdx.x == 0)) - tile_state.SetInclusive(0, block_aggregate); - } - else - { - // Scan non-first 
tile - TilePrefixCallbackOpT prefix_op(tile_state, temp_storage.scan_storage.prefix, scan_op, tile_idx); - BlockScanT(temp_storage.scan_storage.scan).ExclusiveScan(scan_items, scan_items, scan_op, prefix_op); - - block_aggregate = prefix_op.GetBlockAggregate(); - num_segments_prefix = prefix_op.GetExclusivePrefix().key; - total_aggregate = prefix_op.GetInclusivePrefix(); - } - - // Rezip scatter items and segment indices - #pragma unroll - for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) - { - scatter_items[ITEM].key = prev_keys[ITEM]; - scatter_items[ITEM].value = scan_items[ITEM].value; - segment_indices[ITEM] = scan_items[ITEM].key; - } - - // At this point, each flagged segment head has: - // - The key for the previous segment - // - The reduced value from the previous segment - // - The segment index for the reduced value - - // Scatter flagged keys and values - OffsetT num_tile_segments = block_aggregate.key; - Scatter(scatter_items, head_flags, segment_indices, num_tile_segments, num_segments_prefix); - - // Last thread in last tile will output final count (and last pair, if necessary) - if ((IS_LAST_TILE) && (threadIdx.x == BLOCK_THREADS - 1)) - { - OffsetT num_segments = num_segments_prefix + num_tile_segments; - - // If the last tile is a whole tile, output the final_value - if (num_remaining == TILE_ITEMS) - { - d_unique_out[num_segments] = keys[ITEMS_PER_THREAD - 1]; - d_aggregates_out[num_segments] = total_aggregate.value; - num_segments++; - } - - // Output the total number of items selected - *d_num_runs_out = num_segments; - } + // Scan first tile + BlockScanT(temp_storage.scan_storage.scan) + .ExclusiveScan(scan_items, scan_items, scan_op, block_aggregate); + num_segments_prefix = 0; + total_aggregate = block_aggregate; + + // Update tile status if there are successor tiles + if ((!IS_LAST_TILE) && (threadIdx.x == 0)) + { + tile_state.SetInclusive(0, block_aggregate); + } + } + else + { + // Scan non-first tile + TilePrefixCallbackOpT prefix_op(tile_state, + temp_storage.scan_storage.prefix, + scan_op, + tile_idx); + BlockScanT(temp_storage.scan_storage.scan) + .ExclusiveScan(scan_items, scan_items, scan_op, prefix_op); + + block_aggregate = prefix_op.GetBlockAggregate(); + num_segments_prefix = prefix_op.GetExclusivePrefix().key; + total_aggregate = prefix_op.GetInclusivePrefix(); } - - /** - * Scan tiles of items as part of a dynamic chained scan - */ - __device__ __forceinline__ void ConsumeRange( - OffsetT num_items, ///< Total number of input items - ScanTileStateT& tile_state, ///< Global tile state descriptor - int start_tile) ///< The starting tile for the current grid +// Rezip scatter items and segment indices +#pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) { - // Blocks are launched in increasing order, so just assign one tile per block - int tile_idx = start_tile + blockIdx.x; // Current tile index - OffsetT tile_offset = OffsetT(TILE_ITEMS) * tile_idx; // Global offset for the current tile - OffsetT num_remaining = num_items - tile_offset; // Remaining items (including this tile) - - if (num_remaining > TILE_ITEMS) - { - // Not last tile - ConsumeTile(num_remaining, tile_idx, tile_offset, tile_state); - } - else if (num_remaining > 0) - { - // Last tile - ConsumeTile(num_remaining, tile_idx, tile_offset, tile_state); - } + scatter_items[ITEM].key = prev_keys[ITEM]; + scatter_items[ITEM].value = scan_items[ITEM].value; + segment_indices[ITEM] = scan_items[ITEM].key; } + // At this point, each flagged segment head has: + // - The key for 
the previous segment + // - The reduced value from the previous segment + // - The segment index for the reduced value + + // Scatter flagged keys and values + OffsetT num_tile_segments = block_aggregate.key; + Scatter(scatter_items, + head_flags, + segment_indices, + num_tile_segments, + num_segments_prefix); + + // Last thread in last tile will output final count (and last pair, if + // necessary) + if ((IS_LAST_TILE) && (threadIdx.x == BLOCK_THREADS - 1)) + { + OffsetT num_segments = num_segments_prefix + num_tile_segments; + + // If the last tile is a whole tile, output the final_value + if (num_remaining == TILE_ITEMS) + { + d_unique_out[num_segments] = keys[ITEMS_PER_THREAD - 1]; + d_aggregates_out[num_segments] = total_aggregate.value; + num_segments++; + } + + // Output the total number of items selected + *d_num_runs_out = num_segments; + } + } + + /** + * @brief Scan tiles of items as part of a dynamic chained scan + * + * @param num_items + * Total number of input items + * + * @param tile_state + * Global tile state descriptor + * + * @param start_tile + * The starting tile for the current grid + */ + __device__ __forceinline__ void ConsumeRange(OffsetT num_items, + ScanTileStateT &tile_state, + int start_tile) + { + // Blocks are launched in increasing order, so just assign one tile per + // block + + // Current tile index + int tile_idx = start_tile + blockIdx.x; + + // Global offset for the current tile + OffsetT tile_offset = OffsetT(TILE_ITEMS) * tile_idx; + + // Remaining items (including this tile) + OffsetT num_remaining = num_items - tile_offset; + + if (num_remaining > TILE_ITEMS) + { + // Not last tile + ConsumeTile(num_remaining, tile_idx, tile_offset, tile_state); + } + else if (num_remaining > 0) + { + // Last tile + ConsumeTile(num_remaining, tile_idx, tile_offset, tile_state); + } + } }; - CUB_NAMESPACE_END diff --git a/cub/agent/agent_scan.cuh b/cub/agent/agent_scan.cuh index 3702b8c2dc..137e1baa32 100644 --- a/cub/agent/agent_scan.cuh +++ b/cub/agent/agent_scan.cuh @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011, Duane Merrill. All rights reserved. - * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: @@ -13,10 +13,10 @@ * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND @@ -27,464 +27,551 @@ ******************************************************************************/ /** - * \file - * cub::AgentScan implements a stateful abstraction of CUDA thread blocks for participating in device-wide prefix scan . + * @file cub::AgentScan implements a stateful abstraction of CUDA thread blocks + * for participating in device-wide prefix scan . */ #pragma once #include -#include "single_pass_scan_operators.cuh" -#include "../block/block_load.cuh" -#include "../block/block_store.cuh" -#include "../block/block_scan.cuh" -#include "../config.cuh" -#include "../grid/grid_queue.cuh" -#include "../iterator/cache_modified_input_iterator.cuh" +#include +#include +#include +#include +#include +#include +#include CUB_NAMESPACE_BEGIN - /****************************************************************************** * Tuning policy types ******************************************************************************/ /** - * Parameterizable tuning policy type for AgentScan + * @brief Parameterizable tuning policy type for AgentScan + * + * @tparam NOMINAL_BLOCK_THREADS_4B + * Threads per thread block + * + * @tparam NOMINAL_ITEMS_PER_THREAD_4B + * Items per thread (per tile of input) + * + * @tparam ComputeT + * Dominant compute type + * + * @tparam _LOAD_ALGORITHM + * The BlockLoad algorithm to use + * + * @tparam _LOAD_MODIFIER + * Cache load modifier for reading input elements + * + * @tparam _STORE_ALGORITHM + * The BlockStore algorithm to use + * + * @tparam _SCAN_ALGORITHM + * The BlockScan algorithm to use + * */ -template < - int NOMINAL_BLOCK_THREADS_4B, ///< Threads per thread block - int NOMINAL_ITEMS_PER_THREAD_4B, ///< Items per thread (per tile of input) - typename ComputeT, ///< Dominant compute type - BlockLoadAlgorithm _LOAD_ALGORITHM, ///< The BlockLoad algorithm to use - CacheLoadModifier _LOAD_MODIFIER, ///< Cache load modifier for reading input elements - BlockStoreAlgorithm _STORE_ALGORITHM, ///< The BlockStore algorithm to use - BlockScanAlgorithm _SCAN_ALGORITHM, ///< The BlockScan algorithm to use - typename ScalingType = MemBoundScaling > - -struct AgentScanPolicy : - ScalingType +template > +struct AgentScanPolicy : ScalingType { - static const BlockLoadAlgorithm LOAD_ALGORITHM = _LOAD_ALGORITHM; ///< The BlockLoad algorithm to use - static const CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER; ///< Cache load modifier for reading input elements - static const BlockStoreAlgorithm STORE_ALGORITHM = _STORE_ALGORITHM; ///< The BlockStore algorithm to use - static const BlockScanAlgorithm SCAN_ALGORITHM = _SCAN_ALGORITHM; ///< The BlockScan algorithm to use + static const BlockLoadAlgorithm LOAD_ALGORITHM = _LOAD_ALGORITHM; + static const CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER; + static const BlockStoreAlgorithm STORE_ALGORITHM = _STORE_ALGORITHM; + static const BlockScanAlgorithm SCAN_ALGORITHM = _SCAN_ALGORITHM; }; - - - /****************************************************************************** * Thread block abstractions ******************************************************************************/ /** - * \brief AgentScan implements a stateful abstraction of CUDA thread blocks for participating in device-wide prefix scan . 
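For context, a minimal host-side sketch (not part of this patch) of one device-level entry point built on AgentScan; the choice of ExclusiveSum and the buffer names are illustrative assumptions:

#include <cub/cub.cuh>

// Device-wide exclusive prefix sum over num_items integers.
void exclusive_sum_example(const int *d_in, int *d_out, int num_items)
{
  void *d_temp_storage = nullptr;
  size_t temp_storage_bytes = 0;

  // First call computes the required temporary storage size
  cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes,
                                d_in, d_out, num_items);
  cudaMalloc(&d_temp_storage, temp_storage_bytes);

  // Second call performs the scan
  cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes,
                                d_in, d_out, num_items);
  cudaFree(d_temp_storage);
}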
+ * @brief AgentScan implements a stateful abstraction of CUDA thread blocks for + * participating in device-wide prefix scan. + * @tparam AgentScanPolicyT + * Parameterized AgentScanPolicyT tuning policy type + * + * @tparam InputIteratorT + * Random-access input iterator type + * + * @tparam OutputIteratorT + * Random-access output iterator type + * + * @tparam ScanOpT + * Scan functor type + * + * @tparam InitValueT + * The init_value element for ScanOpT type (cub::NullType for inclusive scan) + * + * @tparam OffsetT + * Signed integer type for global offsets + * */ -template < - typename AgentScanPolicyT, ///< Parameterized AgentScanPolicyT tuning policy type - typename InputIteratorT, ///< Random-access input iterator type - typename OutputIteratorT, ///< Random-access output iterator type - typename ScanOpT, ///< Scan functor type - typename InitValueT, ///< The init_value element for ScanOpT type (cub::NullType for inclusive scan) - typename OffsetT> ///< Signed integer type for global offsets +template struct AgentScan { - //--------------------------------------------------------------------- - // Types and constants - //--------------------------------------------------------------------- - - // The input value type - using InputT = cub::detail::value_t; - - // The output value type -- used as the intermediate accumulator - // Per https://wg21.link/P0571, use InitValueT if provided, otherwise the - // input iterator's value type. - using OutputT = - cub::detail::conditional_t::value, - InputT, - InitValueT>; - - // Tile status descriptor interface type - using ScanTileStateT = ScanTileState; - - // Input iterator wrapper type (for applying cache modifier) - // Wrap the native input pointer with CacheModifiedInputIterator - // or directly use the supplied input iterator type - using WrappedInputIteratorT = cub::detail::conditional_t< - std::is_pointer::value, - CacheModifiedInputIterator, - InputIteratorT>; - - // Constants - enum + //--------------------------------------------------------------------- + // Types and constants + //--------------------------------------------------------------------- + + // The input value type + using InputT = cub::detail::value_t; + + // Tile status descriptor interface type + using ScanTileStateT = ScanTileState; + + // Input iterator wrapper type (for applying cache modifier) + // Wrap the native input pointer with CacheModifiedInputIterator + // or directly use the supplied input iterator type + using WrappedInputIteratorT = cub::detail::conditional_t< + std::is_pointer::value, + CacheModifiedInputIterator, + InputIteratorT>; + + // Constants + enum + { + // Inclusive scan if no init_value type is provided + IS_INCLUSIVE = std::is_same::value, + BLOCK_THREADS = AgentScanPolicyT::BLOCK_THREADS, + ITEMS_PER_THREAD = AgentScanPolicyT::ITEMS_PER_THREAD, + TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD, + }; + + // Parameterized BlockLoad type + typedef BlockLoad + BlockLoadT; + + // Parameterized BlockStore type + typedef BlockStore + BlockStoreT; + + // Parameterized BlockScan type + typedef BlockScan + BlockScanT; + + // Callback type for obtaining tile prefix during block scan + typedef TilePrefixCallbackOp + TilePrefixCallbackOpT; + + // Stateful BlockScan prefix callback type for managing a running total while + // scanning consecutive tiles + typedef BlockScanRunningPrefixOp RunningPrefixCallbackOp; + + // Shared memory type for this thread block + union _TempStorage + { + // Smem needed for tile loading + typename BlockLoadT::TempStorage 
load; + + // Smem needed for tile storing + typename BlockStoreT::TempStorage store; + + struct ScanStorage { - // Inclusive scan if no init_value type is provided - IS_INCLUSIVE = std::is_same::value, - BLOCK_THREADS = AgentScanPolicyT::BLOCK_THREADS, - ITEMS_PER_THREAD = AgentScanPolicyT::ITEMS_PER_THREAD, - TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD, - }; - - // Parameterized BlockLoad type - typedef BlockLoad< - OutputT, - AgentScanPolicyT::BLOCK_THREADS, - AgentScanPolicyT::ITEMS_PER_THREAD, - AgentScanPolicyT::LOAD_ALGORITHM> - BlockLoadT; - - // Parameterized BlockStore type - typedef BlockStore< - OutputT, - AgentScanPolicyT::BLOCK_THREADS, - AgentScanPolicyT::ITEMS_PER_THREAD, - AgentScanPolicyT::STORE_ALGORITHM> - BlockStoreT; - - // Parameterized BlockScan type - typedef BlockScan< - OutputT, - AgentScanPolicyT::BLOCK_THREADS, - AgentScanPolicyT::SCAN_ALGORITHM> - BlockScanT; - - // Callback type for obtaining tile prefix during block scan - typedef TilePrefixCallbackOp< - OutputT, - ScanOpT, - ScanTileStateT> - TilePrefixCallbackOpT; - - // Stateful BlockScan prefix callback type for managing a running total while scanning consecutive tiles - typedef BlockScanRunningPrefixOp< - OutputT, - ScanOpT> - RunningPrefixCallbackOp; - - // Shared memory type for this thread block - union _TempStorage + // Smem needed for cooperative prefix callback + typename TilePrefixCallbackOpT::TempStorage prefix; + + // Smem needed for tile scanning + typename BlockScanT::TempStorage scan; + } scan_storage; + }; + + // Alias wrapper allowing storage to be unioned + struct TempStorage : Uninitialized<_TempStorage> + {}; + + //--------------------------------------------------------------------- + // Per-thread fields + //--------------------------------------------------------------------- + + _TempStorage &temp_storage; ///< Reference to temp_storage + WrappedInputIteratorT d_in; ///< Input data + OutputIteratorT d_out; ///< Output data + ScanOpT scan_op; ///< Binary scan operator + InitValueT init_value; ///< The init_value element for ScanOpT + + //--------------------------------------------------------------------- + // Block scan utility methods + //--------------------------------------------------------------------- + + /** + * Exclusive scan specialization (first tile) + */ + __device__ __forceinline__ void ScanTile(AccumT (&items)[ITEMS_PER_THREAD], + AccumT init_value, + ScanOpT scan_op, + AccumT &block_aggregate, + Int2Type /*is_inclusive*/) + { + BlockScanT(temp_storage.scan_storage.scan) + .ExclusiveScan(items, items, init_value, scan_op, block_aggregate); + block_aggregate = scan_op(init_value, block_aggregate); + } + + /** + * Inclusive scan specialization (first tile) + */ + __device__ __forceinline__ void ScanTile(AccumT (&items)[ITEMS_PER_THREAD], + InitValueT /*init_value*/, + ScanOpT scan_op, + AccumT &block_aggregate, + Int2Type /*is_inclusive*/) + { + BlockScanT(temp_storage.scan_storage.scan) + .InclusiveScan(items, items, scan_op, block_aggregate); + } + + /** + * Exclusive scan specialization (subsequent tiles) + */ + template + __device__ __forceinline__ void ScanTile(AccumT (&items)[ITEMS_PER_THREAD], + ScanOpT scan_op, + PrefixCallback &prefix_op, + Int2Type /*is_inclusive*/) + { + BlockScanT(temp_storage.scan_storage.scan) + .ExclusiveScan(items, items, scan_op, prefix_op); + } + + /** + * Inclusive scan specialization (subsequent tiles) + */ + template + __device__ __forceinline__ void ScanTile(AccumT (&items)[ITEMS_PER_THREAD], + ScanOpT scan_op, + PrefixCallback 
&prefix_op, + Int2Type /*is_inclusive*/) + { + BlockScanT(temp_storage.scan_storage.scan) + .InclusiveScan(items, items, scan_op, prefix_op); + } + + //--------------------------------------------------------------------- + // Constructor + //--------------------------------------------------------------------- + + /** + * @param temp_storage + * Reference to temp_storage + * + * @param d_in + * Input data + * + * @param d_out + * Output data + * + * @param scan_op + * Binary scan operator + * + * @param init_value + * Initial value to seed the exclusive scan + */ + __device__ __forceinline__ AgentScan(TempStorage &temp_storage, + InputIteratorT d_in, + OutputIteratorT d_out, + ScanOpT scan_op, + InitValueT init_value) + : temp_storage(temp_storage.Alias()) + , d_in(d_in) + , d_out(d_out) + , scan_op(scan_op) + , init_value(init_value) + {} + + //--------------------------------------------------------------------- + // Cooperatively scan a device-wide sequence of tiles with other CTAs + //--------------------------------------------------------------------- + + /** + * Process a tile of input (dynamic chained scan) + * @tparam IS_LAST_TILE + * Whether the current tile is the last tile + * + * @param num_remaining + * Number of global input items remaining (including this tile) + * + * @param tile_idx + * Tile index + * + * @param tile_offset + * Tile offset + * + * @param tile_state + * Global tile state descriptor + */ + template + __device__ __forceinline__ void ConsumeTile(OffsetT num_remaining, + int tile_idx, + OffsetT tile_offset, + ScanTileStateT &tile_state) + { + // Load items + AccumT items[ITEMS_PER_THREAD]; + + if (IS_LAST_TILE) { - typename BlockLoadT::TempStorage load; // Smem needed for tile loading - typename BlockStoreT::TempStorage store; // Smem needed for tile storing - - struct ScanStorage - { - typename TilePrefixCallbackOpT::TempStorage prefix; // Smem needed for cooperative prefix callback - typename BlockScanT::TempStorage scan; // Smem needed for tile scanning - } scan_storage; - }; - - // Alias wrapper allowing storage to be unioned - struct TempStorage : Uninitialized<_TempStorage> {}; - - - //--------------------------------------------------------------------- - // Per-thread fields - //--------------------------------------------------------------------- - - _TempStorage& temp_storage; ///< Reference to temp_storage - WrappedInputIteratorT d_in; ///< Input data - OutputIteratorT d_out; ///< Output data - ScanOpT scan_op; ///< Binary scan operator - InitValueT init_value; ///< The init_value element for ScanOpT - - - //--------------------------------------------------------------------- - // Block scan utility methods - //--------------------------------------------------------------------- - - /** - * Exclusive scan specialization (first tile) - */ - __device__ __forceinline__ - void ScanTile( - OutputT (&items)[ITEMS_PER_THREAD], - OutputT init_value, - ScanOpT scan_op, - OutputT &block_aggregate, - Int2Type /*is_inclusive*/) + // Fill last element with the first element because collectives are + // not suffix guarded. 
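As a side note, a self-contained sketch (not part of this patch; 128 threads, 4 items per thread, and float data are assumptions) of the guarded BlockLoad overload used here, which fills out-of-bounds slots with a default value because the block-wide collectives that follow are not suffix guarded:

#include <cub/block/block_load.cuh>

__global__ void guarded_load_example(const float *d_in, int valid_items)
{
  using BlockLoadT =
    cub::BlockLoad<float, 128, 4, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
  __shared__ typename BlockLoadT::TempStorage temp_storage;

  float items[4];

  // Guarded load: slots at or beyond valid_items receive d_in[0] as the
  // out-of-bounds default, so a subsequent block-wide scan never reads
  // uninitialized values.
  BlockLoadT(temp_storage).Load(d_in, items, valid_items, d_in[0]);
}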
+ BlockLoadT(temp_storage.load) + .Load(d_in + tile_offset, items, num_remaining, *(d_in + tile_offset)); + } + else { - BlockScanT(temp_storage.scan_storage.scan).ExclusiveScan(items, items, init_value, scan_op, block_aggregate); - block_aggregate = scan_op(init_value, block_aggregate); + BlockLoadT(temp_storage.load).Load(d_in + tile_offset, items); } + CTA_SYNC(); - /** - * Inclusive scan specialization (first tile) - */ - __device__ __forceinline__ - void ScanTile( - OutputT (&items)[ITEMS_PER_THREAD], - InitValueT /*init_value*/, - ScanOpT scan_op, - OutputT &block_aggregate, - Int2Type /*is_inclusive*/) + // Perform tile scan + if (tile_idx == 0) { - BlockScanT(temp_storage.scan_storage.scan).InclusiveScan(items, items, scan_op, block_aggregate); + // Scan first tile + AccumT block_aggregate; + ScanTile(items, + init_value, + scan_op, + block_aggregate, + Int2Type()); + + if ((!IS_LAST_TILE) && (threadIdx.x == 0)) + { + tile_state.SetInclusive(0, block_aggregate); + } } - - - /** - * Exclusive scan specialization (subsequent tiles) - */ - template - __device__ __forceinline__ - void ScanTile( - OutputT (&items)[ITEMS_PER_THREAD], - ScanOpT scan_op, - PrefixCallback &prefix_op, - Int2Type /*is_inclusive*/) + else { - BlockScanT(temp_storage.scan_storage.scan).ExclusiveScan(items, items, scan_op, prefix_op); + // Scan non-first tile + TilePrefixCallbackOpT prefix_op(tile_state, + temp_storage.scan_storage.prefix, + scan_op, + tile_idx); + ScanTile(items, scan_op, prefix_op, Int2Type()); } + CTA_SYNC(); - /** - * Inclusive scan specialization (subsequent tiles) - */ - template - __device__ __forceinline__ - void ScanTile( - OutputT (&items)[ITEMS_PER_THREAD], - ScanOpT scan_op, - PrefixCallback &prefix_op, - Int2Type /*is_inclusive*/) + // Store items + if (IS_LAST_TILE) { - BlockScanT(temp_storage.scan_storage.scan).InclusiveScan(items, items, scan_op, prefix_op); + BlockStoreT(temp_storage.store) + .Store(d_out + tile_offset, items, num_remaining); } - - - //--------------------------------------------------------------------- - // Constructor - //--------------------------------------------------------------------- - - // Constructor - __device__ __forceinline__ - AgentScan( - TempStorage& temp_storage, ///< Reference to temp_storage - InputIteratorT d_in, ///< Input data - OutputIteratorT d_out, ///< Output data - ScanOpT scan_op, ///< Binary scan operator - InitValueT init_value) ///< Initial value to seed the exclusive scan - : - temp_storage(temp_storage.Alias()), - d_in(d_in), - d_out(d_out), - scan_op(scan_op), - init_value(init_value) - {} - - - //--------------------------------------------------------------------- - // Cooperatively scan a device-wide sequence of tiles with other CTAs - //--------------------------------------------------------------------- - - /** - * Process a tile of input (dynamic chained scan) - */ - template ///< Whether the current tile is the last tile - __device__ __forceinline__ void ConsumeTile( - OffsetT num_remaining, ///< Number of global input items remaining (including this tile) - int tile_idx, ///< Tile index - OffsetT tile_offset, ///< Tile offset - ScanTileStateT& tile_state) ///< Global tile state descriptor + else { - // Load items - OutputT items[ITEMS_PER_THREAD]; - - if (IS_LAST_TILE) - { - // Fill last element with the first element because collectives are - // not suffix guarded. 
- BlockLoadT(temp_storage.load) - .Load(d_in + tile_offset, - items, - num_remaining, - *(d_in + tile_offset)); - } - else - { - BlockLoadT(temp_storage.load).Load(d_in + tile_offset, items); - } - - CTA_SYNC(); - - // Perform tile scan - if (tile_idx == 0) - { - // Scan first tile - OutputT block_aggregate; - ScanTile(items, init_value, scan_op, block_aggregate, Int2Type()); - if ((!IS_LAST_TILE) && (threadIdx.x == 0)) - tile_state.SetInclusive(0, block_aggregate); - } - else - { - // Scan non-first tile - TilePrefixCallbackOpT prefix_op(tile_state, temp_storage.scan_storage.prefix, scan_op, tile_idx); - ScanTile(items, scan_op, prefix_op, Int2Type()); - } - - CTA_SYNC(); - - // Store items - if (IS_LAST_TILE) - BlockStoreT(temp_storage.store).Store(d_out + tile_offset, items, num_remaining); - else - BlockStoreT(temp_storage.store).Store(d_out + tile_offset, items); + BlockStoreT(temp_storage.store).Store(d_out + tile_offset, items); } - - - /** - * Scan tiles of items as part of a dynamic chained scan - */ - __device__ __forceinline__ void ConsumeRange( - OffsetT num_items, ///< Total number of input items - ScanTileStateT& tile_state, ///< Global tile state descriptor - int start_tile) ///< The starting tile for the current grid + } + + /** + * @brief Scan tiles of items as part of a dynamic chained scan + * + * @param num_items + * Total number of input items + * + * @param tile_state + * Global tile state descriptor + * + * @param start_tile + * The starting tile for the current grid + */ + __device__ __forceinline__ void ConsumeRange(OffsetT num_items, + ScanTileStateT &tile_state, + int start_tile) + { + // Blocks are launched in increasing order, so just assign one tile per + // block + + // Current tile index + int tile_idx = start_tile + blockIdx.x; + + // Global offset for the current tile + OffsetT tile_offset = OffsetT(TILE_ITEMS) * tile_idx; + + // Remaining items (including this tile) + OffsetT num_remaining = num_items - tile_offset; + + if (num_remaining > TILE_ITEMS) { - // Blocks are launched in increasing order, so just assign one tile per block - int tile_idx = start_tile + blockIdx.x; // Current tile index - OffsetT tile_offset = OffsetT(TILE_ITEMS) * tile_idx; // Global offset for the current tile - OffsetT num_remaining = num_items - tile_offset; // Remaining items (including this tile) - - if (num_remaining > TILE_ITEMS) - { - // Not last tile - ConsumeTile(num_remaining, tile_idx, tile_offset, tile_state); - } - else if (num_remaining > 0) - { - // Last tile - ConsumeTile(num_remaining, tile_idx, tile_offset, tile_state); - } + // Not last tile + ConsumeTile(num_remaining, tile_idx, tile_offset, tile_state); } - - - //--------------------------------------------------------------------- - // Scan an sequence of consecutive tiles (independent of other thread blocks) - //--------------------------------------------------------------------- - - /** - * Process a tile of input - */ - template < - bool IS_FIRST_TILE, - bool IS_LAST_TILE> - __device__ __forceinline__ void ConsumeTile( - OffsetT tile_offset, ///< Tile offset - RunningPrefixCallbackOp& prefix_op, ///< Running prefix operator - int valid_items = TILE_ITEMS) ///< Number of valid items in the tile + else if (num_remaining > 0) { - // Load items - OutputT items[ITEMS_PER_THREAD]; - - if (IS_LAST_TILE) - { - // Fill last element with the first element because collectives are - // not suffix guarded. 
- BlockLoadT(temp_storage.load) - .Load(d_in + tile_offset, - items, - valid_items, - *(d_in + tile_offset)); - } - else - { - BlockLoadT(temp_storage.load).Load(d_in + tile_offset, items); - } - - CTA_SYNC(); - - // Block scan - if (IS_FIRST_TILE) - { - OutputT block_aggregate; - ScanTile(items, init_value, scan_op, block_aggregate, Int2Type()); - prefix_op.running_total = block_aggregate; - } - else - { - ScanTile(items, scan_op, prefix_op, Int2Type()); - } - - CTA_SYNC(); - - // Store items - if (IS_LAST_TILE) - BlockStoreT(temp_storage.store).Store(d_out + tile_offset, items, valid_items); - else - BlockStoreT(temp_storage.store).Store(d_out + tile_offset, items); + // Last tile + ConsumeTile(num_remaining, tile_idx, tile_offset, tile_state); + } + } + + //--------------------------------------------------------------------------- + // Scan an sequence of consecutive tiles (independent of other thread blocks) + //--------------------------------------------------------------------------- + + /** + * @brief Process a tile of input + * + * @param tile_offset + * Tile offset + * + * @param prefix_op + * Running prefix operator + * + * @param valid_items + * Number of valid items in the tile + */ + template + __device__ __forceinline__ void ConsumeTile(OffsetT tile_offset, + RunningPrefixCallbackOp &prefix_op, + int valid_items = TILE_ITEMS) + { + // Load items + AccumT items[ITEMS_PER_THREAD]; + + if (IS_LAST_TILE) + { + // Fill last element with the first element because collectives are + // not suffix guarded. + BlockLoadT(temp_storage.load) + .Load(d_in + tile_offset, items, valid_items, *(d_in + tile_offset)); + } + else + { + BlockLoadT(temp_storage.load).Load(d_in + tile_offset, items); } + CTA_SYNC(); - /** - * Scan a consecutive share of input tiles - */ - __device__ __forceinline__ void ConsumeRange( - OffsetT range_offset, ///< [in] Threadblock begin offset (inclusive) - OffsetT range_end) ///< [in] Threadblock end offset (exclusive) + // Block scan + if (IS_FIRST_TILE) + { + AccumT block_aggregate; + ScanTile(items, + init_value, + scan_op, + block_aggregate, + Int2Type()); + prefix_op.running_total = block_aggregate; + } + else { - BlockScanRunningPrefixOp prefix_op(scan_op); - - if (range_offset + TILE_ITEMS <= range_end) - { - // Consume first tile of input (full) - ConsumeTile(range_offset, prefix_op); - range_offset += TILE_ITEMS; - - // Consume subsequent full tiles of input - while (range_offset + TILE_ITEMS <= range_end) - { - ConsumeTile(range_offset, prefix_op); - range_offset += TILE_ITEMS; - } - - // Consume a partially-full tile - if (range_offset < range_end) - { - int valid_items = range_end - range_offset; - ConsumeTile(range_offset, prefix_op, valid_items); - } - } - else - { - // Consume the first tile of input (partially-full) - int valid_items = range_end - range_offset; - ConsumeTile(range_offset, prefix_op, valid_items); - } + ScanTile(items, scan_op, prefix_op, Int2Type()); } + CTA_SYNC(); - /** - * Scan a consecutive share of input tiles, seeded with the specified prefix value - */ - __device__ __forceinline__ void ConsumeRange( - OffsetT range_offset, ///< [in] Threadblock begin offset (inclusive) - OffsetT range_end, ///< [in] Threadblock end offset (exclusive) - OutputT prefix) ///< [in] The prefix to apply to the scan segment + // Store items + if (IS_LAST_TILE) { - BlockScanRunningPrefixOp prefix_op(prefix, scan_op); - - // Consume full tiles of input - while (range_offset + TILE_ITEMS <= range_end) - { - ConsumeTile(range_offset, prefix_op); - 
range_offset += TILE_ITEMS; - } - - // Consume a partially-full tile - if (range_offset < range_end) - { - int valid_items = range_end - range_offset; - ConsumeTile(range_offset, prefix_op, valid_items); - } + BlockStoreT(temp_storage.store) + .Store(d_out + tile_offset, items, valid_items); + } + else + { + BlockStoreT(temp_storage.store).Store(d_out + tile_offset, items); + } + } + + /** + * @brief Scan a consecutive share of input tiles + * + * @param[in] range_offset + * Threadblock begin offset (inclusive) + * + * @param[in] range_end + * Threadblock end offset (exclusive) + */ + __device__ __forceinline__ void ConsumeRange(OffsetT range_offset, + OffsetT range_end) + { + BlockScanRunningPrefixOp prefix_op(scan_op); + + if (range_offset + TILE_ITEMS <= range_end) + { + // Consume first tile of input (full) + ConsumeTile(range_offset, prefix_op); + range_offset += TILE_ITEMS; + + // Consume subsequent full tiles of input + while (range_offset + TILE_ITEMS <= range_end) + { + ConsumeTile(range_offset, prefix_op); + range_offset += TILE_ITEMS; + } + + // Consume a partially-full tile + if (range_offset < range_end) + { + int valid_items = range_end - range_offset; + ConsumeTile(range_offset, prefix_op, valid_items); + } + } + else + { + // Consume the first tile of input (partially-full) + int valid_items = range_end - range_offset; + ConsumeTile(range_offset, prefix_op, valid_items); + } + } + + /** + * @brief Scan a consecutive share of input tiles, seeded with the + * specified prefix value + * @param[in] range_offset + * Threadblock begin offset (inclusive) + * + * @param[in] range_end + * Threadblock end offset (exclusive) + * + * @param[in] prefix + * The prefix to apply to the scan segment + */ + __device__ __forceinline__ void ConsumeRange(OffsetT range_offset, + OffsetT range_end, + AccumT prefix) + { + BlockScanRunningPrefixOp prefix_op(prefix, scan_op); + + // Consume full tiles of input + while (range_offset + TILE_ITEMS <= range_end) + { + ConsumeTile(range_offset, prefix_op); + range_offset += TILE_ITEMS; } + // Consume a partially-full tile + if (range_offset < range_end) + { + int valid_items = range_end - range_offset; + ConsumeTile(range_offset, prefix_op, valid_items); + } + } }; - CUB_NAMESPACE_END diff --git a/cub/agent/agent_scan_by_key.cuh b/cub/agent/agent_scan_by_key.cuh index b7552cc38a..ca6c7db5ac 100644 --- a/cub/agent/agent_scan_by_key.cuh +++ b/cub/agent/agent_scan_by_key.cuh @@ -1,5 +1,5 @@ /****************************************************************************** - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: @@ -12,10 +12,10 @@ * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND @@ -26,8 +26,8 @@ ******************************************************************************/ /** - * \file - * AgentScanByKey implements a stateful abstraction of CUDA thread blocks for participating in device-wide prefix scan by key. + * @file AgentScanByKey implements a stateful abstraction of CUDA thread blocks + * for participating in device-wide prefix scan by key. */ #pragma once @@ -43,10 +43,8 @@ #include - CUB_NAMESPACE_BEGIN - /****************************************************************************** * Tuning policy types ******************************************************************************/ @@ -54,422 +52,453 @@ CUB_NAMESPACE_BEGIN /** * Parameterizable tuning policy type for AgentScanByKey */ - -template +template struct AgentScanByKeyPolicy { - enum - { - BLOCK_THREADS = _BLOCK_THREADS, - ITEMS_PER_THREAD = _ITEMS_PER_THREAD, - }; - - static const BlockLoadAlgorithm LOAD_ALGORITHM = _LOAD_ALGORITHM; - static const CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER; - static const BlockScanAlgorithm SCAN_ALGORITHM = _SCAN_ALGORITHM; - static const BlockStoreAlgorithm STORE_ALGORITHM = _STORE_ALGORITHM; -}; + static constexpr int BLOCK_THREADS = _BLOCK_THREADS; + static constexpr int ITEMS_PER_THREAD = _ITEMS_PER_THREAD; + static constexpr BlockLoadAlgorithm LOAD_ALGORITHM = _LOAD_ALGORITHM; + static constexpr CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER; + static constexpr BlockScanAlgorithm SCAN_ALGORITHM = _SCAN_ALGORITHM; + static constexpr BlockStoreAlgorithm STORE_ALGORITHM = _STORE_ALGORITHM; +}; /****************************************************************************** * Thread block abstractions ******************************************************************************/ /** - * \brief AgentScanByKey implements a stateful abstraction of CUDA thread blocks for participating in device-wide prefix scan by key. + * @brief AgentScanByKey implements a stateful abstraction of CUDA thread + * blocks for participating in device-wide prefix scan by key. 
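+ *
+ * The agent zips each value with a segment-head flag (1 at the first item of
+ * a new key, 0 elsewhere) into a key-value pair and scans the pairs with
+ * ReduceBySegmentOp, so the running aggregate restarts at every segment head.
+ * For example, an inclusive sum by key over keys = {0, 0, 1, 1, 1} and
+ * values = {1, 2, 3, 4, 5} yields {1, 3, 3, 7, 12}.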
+ * + * @tparam AgentScanByKeyPolicyT + * Parameterized AgentScanPolicyT tuning policy type + * + * @tparam KeysInputIteratorT + * Random-access input iterator type + * + * @tparam ValuesInputIteratorT + * Random-access input iterator type + * + * @tparam ValuesOutputIteratorT + * Random-access output iterator type + * + * @tparam EqualityOp + * Equality functor type + * + * @tparam ScanOpT + * Scan functor type + * + * @tparam InitValueT + * The init_value element for ScanOpT type (cub::NullType for inclusive scan) + * + * @tparam OffsetT + * Signed integer type for global offsets + * */ -template < - typename AgentScanByKeyPolicyT, ///< Parameterized AgentScanPolicyT tuning policy type - typename KeysInputIteratorT, ///< Random-access input iterator type - typename ValuesInputIteratorT, ///< Random-access input iterator type - typename ValuesOutputIteratorT, ///< Random-access output iterator type - typename EqualityOp, ///< Equality functor type - typename ScanOpT, ///< Scan functor type - typename InitValueT, ///< The init_value element for ScanOpT type (cub::NullType for inclusive scan) - typename OffsetT> ///< Signed integer type for global offsets +template struct AgentScanByKey { - //--------------------------------------------------------------------- - // Types and constants - //--------------------------------------------------------------------- - - using KeyT = cub::detail::value_t; - using InputT = cub::detail::value_t; - - // The output value type -- used as the intermediate accumulator - // Per https://wg21.link/P0571, use InitValueT if provided, otherwise the - // input iterator's value type. - using OutputT = - cub::detail::conditional_t::value, - InputT, - InitValueT>; - - using SizeValuePairT = KeyValuePair; - using KeyValuePairT = KeyValuePair; - using ReduceBySegmentOpT = ReduceBySegmentOp; - - using ScanTileStateT = ReduceByKeyScanTileState; - - // Constants - enum + //--------------------------------------------------------------------- + // Types and constants + //--------------------------------------------------------------------- + + using KeyT = cub::detail::value_t; + using InputT = cub::detail::value_t; + using SizeValuePairT = KeyValuePair; + using KeyValuePairT = KeyValuePair; + using ReduceBySegmentOpT = ReduceBySegmentOp; + + using ScanTileStateT = ReduceByKeyScanTileState; + + // Constants + // Inclusive scan if no init_value type is provided + static constexpr int IS_INCLUSIVE = std::is_same::value; + static constexpr int BLOCK_THREADS = AgentScanByKeyPolicyT::BLOCK_THREADS; + static constexpr int ITEMS_PER_THREAD = + AgentScanByKeyPolicyT::ITEMS_PER_THREAD; + static constexpr int ITEMS_PER_TILE = BLOCK_THREADS * ITEMS_PER_THREAD; + + using WrappedKeysInputIteratorT = cub::detail::conditional_t< + std::is_pointer::value, + CacheModifiedInputIterator, + KeysInputIteratorT>; + + using WrappedValuesInputIteratorT = cub::detail::conditional_t< + std::is_pointer::value, + CacheModifiedInputIterator, + ValuesInputIteratorT>; + + using BlockLoadKeysT = BlockLoad; + + using BlockLoadValuesT = BlockLoad; + + using BlockStoreValuesT = BlockStore; + + using BlockDiscontinuityKeysT = BlockDiscontinuity; + + using TilePrefixCallbackT = + TilePrefixCallbackOp; + + using BlockScanT = BlockScan; + + union TempStorage_ + { + struct ScanStorage { - IS_INCLUSIVE = std::is_same::value, // Inclusive scan if no init_value type is provided - BLOCK_THREADS = AgentScanByKeyPolicyT::BLOCK_THREADS, - ITEMS_PER_THREAD = AgentScanByKeyPolicyT::ITEMS_PER_THREAD, - ITEMS_PER_TILE = 
BLOCK_THREADS * ITEMS_PER_THREAD, - }; - - using WrappedKeysInputIteratorT = cub::detail::conditional_t::value, - CacheModifiedInputIterator, // Wrap the native input pointer with CacheModifiedInputIterator - KeysInputIteratorT>; - using WrappedValuesInputIteratorT = cub::detail::conditional_t::value, - CacheModifiedInputIterator, // Wrap the native input pointer with CacheModifiedInputIterator - ValuesInputIteratorT>; - - using BlockLoadKeysT = BlockLoad; - using BlockLoadValuesT = BlockLoad; - using BlockStoreValuesT = BlockStore; - using BlockDiscontinuityKeysT = BlockDiscontinuity; - - using TilePrefixCallbackT = TilePrefixCallbackOp; - using BlockScanT = BlockScan; - - union TempStorage_ + typename BlockScanT::TempStorage scan; + typename TilePrefixCallbackT::TempStorage prefix; + typename BlockDiscontinuityKeysT::TempStorage discontinuity; + } scan_storage; + + typename BlockLoadKeysT::TempStorage load_keys; + typename BlockLoadValuesT::TempStorage load_values; + typename BlockStoreValuesT::TempStorage store_values; + }; + + struct TempStorage : cub::Uninitialized + {}; + + //--------------------------------------------------------------------- + // Per-thread fields + //--------------------------------------------------------------------- + + TempStorage_ &storage; + WrappedKeysInputIteratorT d_keys_in; + KeyT *d_keys_prev_in; + WrappedValuesInputIteratorT d_values_in; + ValuesOutputIteratorT d_values_out; + InequalityWrapper inequality_op; + ScanOpT scan_op; + ReduceBySegmentOpT pair_scan_op; + InitValueT init_value; + + //--------------------------------------------------------------------- + // Block scan utility methods (first tile) + //--------------------------------------------------------------------- + + // Exclusive scan specialization + __device__ __forceinline__ void + ScanTile(SizeValuePairT (&scan_items)[ITEMS_PER_THREAD], + SizeValuePairT &tile_aggregate, + Int2Type /* is_inclusive */) + { + BlockScanT(storage.scan_storage.scan) + .ExclusiveScan(scan_items, scan_items, pair_scan_op, tile_aggregate); + } + + // Inclusive scan specialization + __device__ __forceinline__ void + ScanTile(SizeValuePairT (&scan_items)[ITEMS_PER_THREAD], + SizeValuePairT &tile_aggregate, + Int2Type /* is_inclusive */) + { + BlockScanT(storage.scan_storage.scan) + .InclusiveScan(scan_items, scan_items, pair_scan_op, tile_aggregate); + } + + //--------------------------------------------------------------------- + // Block scan utility methods (subsequent tiles) + //--------------------------------------------------------------------- + + // Exclusive scan specialization (with prefix from predecessors) + __device__ __forceinline__ void + ScanTile(SizeValuePairT (&scan_items)[ITEMS_PER_THREAD], + SizeValuePairT &tile_aggregate, + TilePrefixCallbackT &prefix_op, + Int2Type /* is_incclusive */) + { + BlockScanT(storage.scan_storage.scan) + .ExclusiveScan(scan_items, scan_items, pair_scan_op, prefix_op); + tile_aggregate = prefix_op.GetBlockAggregate(); + } + + // Inclusive scan specialization (with prefix from predecessors) + __device__ __forceinline__ void + ScanTile(SizeValuePairT (&scan_items)[ITEMS_PER_THREAD], + SizeValuePairT &tile_aggregate, + TilePrefixCallbackT &prefix_op, + Int2Type /* is_inclusive */) + { + BlockScanT(storage.scan_storage.scan) + .InclusiveScan(scan_items, scan_items, pair_scan_op, prefix_op); + tile_aggregate = prefix_op.GetBlockAggregate(); + } + + //--------------------------------------------------------------------- + // Zip utility methods + 
//--------------------------------------------------------------------- + + template + __device__ __forceinline__ void + ZipValuesAndFlags(OffsetT num_remaining, + AccumT (&values)[ITEMS_PER_THREAD], + OffsetT (&segment_flags)[ITEMS_PER_THREAD], + SizeValuePairT (&scan_items)[ITEMS_PER_THREAD]) + { +// Zip values and segment_flags +#pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) { - struct ScanStorage - { - typename BlockScanT::TempStorage scan; - typename TilePrefixCallbackT::TempStorage prefix; - typename BlockDiscontinuityKeysT::TempStorage discontinuity; - } scan_storage; - - typename BlockLoadKeysT::TempStorage load_keys; - typename BlockLoadValuesT::TempStorage load_values; - typename BlockStoreValuesT::TempStorage store_values; - }; - - struct TempStorage : cub::Uninitialized {}; - - //--------------------------------------------------------------------- - // Per-thread fields - //--------------------------------------------------------------------- - - TempStorage_ &storage; - WrappedKeysInputIteratorT d_keys_in; - KeyT* d_keys_prev_in; - WrappedValuesInputIteratorT d_values_in; - ValuesOutputIteratorT d_values_out; - InequalityWrapper inequality_op; - ScanOpT scan_op; - ReduceBySegmentOpT pair_scan_op; - InitValueT init_value; - - //--------------------------------------------------------------------- - // Block scan utility methods (first tile) - //--------------------------------------------------------------------- - - // Exclusive scan specialization - __device__ __forceinline__ - void ScanTile( - SizeValuePairT (&scan_items)[ITEMS_PER_THREAD], - SizeValuePairT &tile_aggregate, - Int2Type /* is_inclusive */) - { - BlockScanT(storage.scan_storage.scan) - .ExclusiveScan(scan_items, scan_items, pair_scan_op, tile_aggregate); + // Set segment_flags for first out-of-bounds item, zero for others + if (IS_LAST_TILE && + OffsetT(threadIdx.x * ITEMS_PER_THREAD) + ITEM == num_remaining) + { + segment_flags[ITEM] = 1; + } + + scan_items[ITEM].value = values[ITEM]; + scan_items[ITEM].key = segment_flags[ITEM]; } - - // Inclusive scan specialization - __device__ __forceinline__ - void ScanTile( - SizeValuePairT (&scan_items)[ITEMS_PER_THREAD], - SizeValuePairT &tile_aggregate, - Int2Type /* is_inclusive */) + } + + __device__ __forceinline__ void + UnzipValues(AccumT (&values)[ITEMS_PER_THREAD], + SizeValuePairT (&scan_items)[ITEMS_PER_THREAD]) + { +// Zip values and segment_flags +#pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) { - BlockScanT(storage.scan_storage.scan) - .InclusiveScan(scan_items, scan_items, pair_scan_op, tile_aggregate); + values[ITEM] = scan_items[ITEM].value; } - - //--------------------------------------------------------------------- - // Block scan utility methods (subsequent tiles) - //--------------------------------------------------------------------- - - // Exclusive scan specialization (with prefix from predecessors) - __device__ __forceinline__ - void ScanTile( - SizeValuePairT (&scan_items)[ITEMS_PER_THREAD], - SizeValuePairT & tile_aggregate, - TilePrefixCallbackT &prefix_op, - Int2Type /* is_incclusive */) + } + + template ::value, + typename std::enable_if::type = 0> + __device__ __forceinline__ void + AddInitToScan(AccumT (&items)[ITEMS_PER_THREAD], + OffsetT (&flags)[ITEMS_PER_THREAD]) + { +#pragma unroll + for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) { - BlockScanT(storage.scan_storage.scan) - .ExclusiveScan(scan_items, scan_items, pair_scan_op, prefix_op); - tile_aggregate = 
prefix_op.GetBlockAggregate(); + items[ITEM] = flags[ITEM] ? init_value : scan_op(init_value, items[ITEM]); } - - // Inclusive scan specialization (with prefix from predecessors) - __device__ __forceinline__ - void ScanTile( - SizeValuePairT (&scan_items)[ITEMS_PER_THREAD], - SizeValuePairT & tile_aggregate, - TilePrefixCallbackT &prefix_op, - Int2Type /* is_inclusive */) + } + + template ::value, + typename std::enable_if::type = 0> + __device__ __forceinline__ void + AddInitToScan(AccumT (&/*items*/)[ITEMS_PER_THREAD], + OffsetT (&/*flags*/)[ITEMS_PER_THREAD]) + {} + + //--------------------------------------------------------------------- + // Cooperatively scan a device-wide sequence of tiles with other CTAs + //--------------------------------------------------------------------- + + // Process a tile of input (dynamic chained scan) + // + template + __device__ __forceinline__ void ConsumeTile(OffsetT /*num_items*/, + OffsetT num_remaining, + int tile_idx, + OffsetT tile_base, + ScanTileStateT &tile_state) + { + // Load items + KeyT keys[ITEMS_PER_THREAD]; + AccumT values[ITEMS_PER_THREAD]; + OffsetT segment_flags[ITEMS_PER_THREAD]; + SizeValuePairT scan_items[ITEMS_PER_THREAD]; + + if (IS_LAST_TILE) { - BlockScanT(storage.scan_storage.scan) - .InclusiveScan(scan_items, scan_items, pair_scan_op, prefix_op); - tile_aggregate = prefix_op.GetBlockAggregate(); + // Fill last element with the first element + // because collectives are not suffix guarded + BlockLoadKeysT(storage.load_keys) + .Load(d_keys_in + tile_base, + keys, + num_remaining, + *(d_keys_in + tile_base)); } - - //--------------------------------------------------------------------- - // Zip utility methods - //--------------------------------------------------------------------- - - template - __device__ __forceinline__ - void ZipValuesAndFlags( - OffsetT num_remaining, - OutputT (&values)[ITEMS_PER_THREAD], - OffsetT (&segment_flags)[ITEMS_PER_THREAD], - SizeValuePairT (&scan_items)[ITEMS_PER_THREAD]) + else { - // Zip values and segment_flags - #pragma unroll - for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) - { - // Set segment_flags for first out-of-bounds item, zero for others - if (IS_LAST_TILE && - OffsetT(threadIdx.x * ITEMS_PER_THREAD) + ITEM == num_remaining) - segment_flags[ITEM] = 1; - - scan_items[ITEM].value = values[ITEM]; - scan_items[ITEM].key = segment_flags[ITEM]; - } + BlockLoadKeysT(storage.load_keys).Load(d_keys_in + tile_base, keys); } - __device__ __forceinline__ - void UnzipValues( - OutputT (&values)[ITEMS_PER_THREAD], - SizeValuePairT (&scan_items)[ITEMS_PER_THREAD]) - { - // Zip values and segment_flags - #pragma unroll - for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) - { - values[ITEM] = scan_items[ITEM].value; - } - } + CTA_SYNC(); - template ::value, - typename std::enable_if::type = 0> - __device__ __forceinline__ void AddInitToScan( - OutputT (&items)[ITEMS_PER_THREAD], - OffsetT (&flags)[ITEMS_PER_THREAD]) + if (IS_LAST_TILE) { - #pragma unroll - for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) - { - items[ITEM] = flags[ITEM] ? 
init_value : scan_op(init_value, items[ITEM]); - } + // Fill last element with the first element + // because collectives are not suffix guarded + BlockLoadValuesT(storage.load_values) + .Load(d_values_in + tile_base, + values, + num_remaining, + *(d_values_in + tile_base)); } - - template ::value, - typename std::enable_if::type = 0> - __device__ __forceinline__ - void AddInitToScan( - OutputT (&/*items*/)[ITEMS_PER_THREAD], - OffsetT (&/*flags*/)[ITEMS_PER_THREAD]) - {} - - //--------------------------------------------------------------------- - // Cooperatively scan a device-wide sequence of tiles with other CTAs - //--------------------------------------------------------------------- - - // Process a tile of input (dynamic chained scan) - // - template - __device__ __forceinline__ - void ConsumeTile( - OffsetT /*num_items*/, - OffsetT num_remaining, - int tile_idx, - OffsetT tile_base, - ScanTileStateT& tile_state) + else { - // Load items - KeyT keys[ITEMS_PER_THREAD]; - OutputT values[ITEMS_PER_THREAD]; - OffsetT segment_flags[ITEMS_PER_THREAD]; - SizeValuePairT scan_items[ITEMS_PER_THREAD]; - - if (IS_LAST_TILE) - { - // Fill last element with the first element - // because collectives are not suffix guarded - BlockLoadKeysT(storage.load_keys) - .Load(d_keys_in + tile_base, - keys, - num_remaining, - *(d_keys_in + tile_base)); - } - else - { - BlockLoadKeysT(storage.load_keys) - .Load(d_keys_in + tile_base, keys); - } + BlockLoadValuesT(storage.load_values) + .Load(d_values_in + tile_base, values); + } - CTA_SYNC(); + CTA_SYNC(); - if (IS_LAST_TILE) - { - // Fill last element with the first element - // because collectives are not suffix guarded - BlockLoadValuesT(storage.load_values) - .Load(d_values_in + tile_base, - values, - num_remaining, - *(d_values_in + tile_base)); - } - else + // first tile + if (tile_idx == 0) + { + BlockDiscontinuityKeysT(storage.scan_storage.discontinuity) + .FlagHeads(segment_flags, keys, inequality_op); + + // Zip values and segment_flags + ZipValuesAndFlags(num_remaining, + values, + segment_flags, + scan_items); + + // Exclusive scan of values and segment_flags + SizeValuePairT tile_aggregate; + ScanTile(scan_items, tile_aggregate, Int2Type()); + + if (threadIdx.x == 0) + { + if (!IS_LAST_TILE) { - BlockLoadValuesT(storage.load_values) - .Load(d_values_in + tile_base, values); + tile_state.SetInclusive(0, tile_aggregate); } - CTA_SYNC(); - - // first tile - if (tile_idx == 0) - { - BlockDiscontinuityKeysT(storage.scan_storage.discontinuity) - .FlagHeads(segment_flags, keys, inequality_op); - - // Zip values and segment_flags - ZipValuesAndFlags(num_remaining, - values, - segment_flags, - scan_items); - - // Exclusive scan of values and segment_flags - SizeValuePairT tile_aggregate; - ScanTile(scan_items, tile_aggregate, Int2Type()); - - if (threadIdx.x == 0) - { - if (!IS_LAST_TILE) - tile_state.SetInclusive(0, tile_aggregate); - - scan_items[0].key = 0; - } - } - else - { - KeyT tile_pred_key = (threadIdx.x == 0) ? 
d_keys_prev_in[tile_idx] - : KeyT(); - - BlockDiscontinuityKeysT(storage.scan_storage.discontinuity) - .FlagHeads(segment_flags, keys, inequality_op, tile_pred_key); - - // Zip values and segment_flags - ZipValuesAndFlags(num_remaining, - values, - segment_flags, - scan_items); - - SizeValuePairT tile_aggregate; - TilePrefixCallbackT prefix_op(tile_state, - storage.scan_storage.prefix, - pair_scan_op, - tile_idx); - ScanTile(scan_items, - tile_aggregate, - prefix_op, - Int2Type()); - } + scan_items[0].key = 0; + } + } + else + { + KeyT tile_pred_key = (threadIdx.x == 0) ? d_keys_prev_in[tile_idx] + : KeyT(); + + BlockDiscontinuityKeysT(storage.scan_storage.discontinuity) + .FlagHeads(segment_flags, keys, inequality_op, tile_pred_key); + + // Zip values and segment_flags + ZipValuesAndFlags(num_remaining, + values, + segment_flags, + scan_items); + + SizeValuePairT tile_aggregate; + TilePrefixCallbackT prefix_op(tile_state, + storage.scan_storage.prefix, + pair_scan_op, + tile_idx); + ScanTile(scan_items, tile_aggregate, prefix_op, Int2Type()); + } - CTA_SYNC(); + CTA_SYNC(); - UnzipValues(values, scan_items); + UnzipValues(values, scan_items); - AddInitToScan(values, segment_flags); + AddInitToScan(values, segment_flags); - // Store items - if (IS_LAST_TILE) - { - BlockStoreValuesT(storage.store_values) - .Store(d_values_out + tile_base, values, num_remaining); - } - else - { - BlockStoreValuesT(storage.store_values) - .Store(d_values_out + tile_base, values); - } + // Store items + if (IS_LAST_TILE) + { + BlockStoreValuesT(storage.store_values) + .Store(d_values_out + tile_base, values, num_remaining); } - - //--------------------------------------------------------------------- - // Constructor - //--------------------------------------------------------------------- - - // Dequeue and scan tiles of items as part of a dynamic chained scan - // with Init functor - __device__ __forceinline__ - AgentScanByKey( - TempStorage & storage, - KeysInputIteratorT d_keys_in, - KeyT * d_keys_prev_in, - ValuesInputIteratorT d_values_in, - ValuesOutputIteratorT d_values_out, - EqualityOp equality_op, - ScanOpT scan_op, - InitValueT init_value) - : - storage(storage.Alias()), - d_keys_in(d_keys_in), - d_keys_prev_in(d_keys_prev_in), - d_values_in(d_values_in), - d_values_out(d_values_out), - inequality_op(equality_op), - scan_op(scan_op), - pair_scan_op(scan_op), - init_value(init_value) - {} - - /** - * Scan tiles of items as part of a dynamic chained scan - */ - __device__ __forceinline__ void ConsumeRange( - OffsetT num_items, ///< Total number of input items - ScanTileStateT& tile_state, ///< Global tile state descriptor - int start_tile) ///< The starting tile for the current grid + else { - int tile_idx = blockIdx.x; - OffsetT tile_base = OffsetT(ITEMS_PER_TILE) * tile_idx; - OffsetT num_remaining = num_items - tile_base; - - if (num_remaining > ITEMS_PER_TILE) - { - // Not the last tile (full) - ConsumeTile(num_items, - num_remaining, - tile_idx, - tile_base, - tile_state); - } - else if (num_remaining > 0) - { - // The last tile (possibly partially-full) - ConsumeTile(num_items, - num_remaining, - tile_idx, - tile_base, - tile_state); - } + BlockStoreValuesT(storage.store_values) + .Store(d_values_out + tile_base, values); + } + } + + //--------------------------------------------------------------------- + // Constructor + //--------------------------------------------------------------------- + + // Dequeue and scan tiles of items as part of a dynamic chained scan + // with Init functor + 
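+  //
+  // Illustrative sketch of how a launching kernel drives this agent
+  // (mirroring DeviceReduceByKeyKernel further down in this diff; the names
+  // below are placeholders for the kernel's actual parameters):
+  //
+  //   __shared__ typename AgentScanByKeyT::TempStorage temp_storage;
+  //   AgentScanByKeyT(temp_storage, d_keys_in, d_keys_prev_in, d_values_in,
+  //                   d_values_out, equality_op, scan_op, init_value)
+  //     .ConsumeRange(num_items, tile_state, start_tile);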
__device__ __forceinline__ AgentScanByKey(TempStorage &storage, + KeysInputIteratorT d_keys_in, + KeyT *d_keys_prev_in, + ValuesInputIteratorT d_values_in, + ValuesOutputIteratorT d_values_out, + EqualityOp equality_op, + ScanOpT scan_op, + InitValueT init_value) + : storage(storage.Alias()) + , d_keys_in(d_keys_in) + , d_keys_prev_in(d_keys_prev_in) + , d_values_in(d_values_in) + , d_values_out(d_values_out) + , inequality_op(equality_op) + , scan_op(scan_op) + , pair_scan_op(scan_op) + , init_value(init_value) + {} + + /** + * Scan tiles of items as part of a dynamic chained scan + * + * @param num_items + * Total number of input items + * + * @param tile_state + * Global tile state descriptor + * + * start_tile + * The starting tile for the current grid + */ + __device__ __forceinline__ void ConsumeRange(OffsetT num_items, + ScanTileStateT &tile_state, + int start_tile) + { + int tile_idx = blockIdx.x; + OffsetT tile_base = OffsetT(ITEMS_PER_TILE) * tile_idx; + OffsetT num_remaining = num_items - tile_base; + + if (num_remaining > ITEMS_PER_TILE) + { + // Not the last tile (full) + ConsumeTile(num_items, + num_remaining, + tile_idx, + tile_base, + tile_state); } + else if (num_remaining > 0) + { + // The last tile (possibly partially-full) + ConsumeTile(num_items, + num_remaining, + tile_idx, + tile_base, + tile_state); + } + } }; - CUB_NAMESPACE_END + diff --git a/cub/agent/single_pass_scan_operators.cuh b/cub/agent/single_pass_scan_operators.cuh index 4f21863a7f..63fbd7c85e 100644 --- a/cub/agent/single_pass_scan_operators.cuh +++ b/cub/agent/single_pass_scan_operators.cuh @@ -35,11 +35,12 @@ #include -#include "../thread/thread_load.cuh" -#include "../thread/thread_store.cuh" -#include "../warp/warp_reduce.cuh" -#include "../config.cuh" -#include "../util_device.cuh" +#include +#include +#include +#include +#include +#include CUB_NAMESPACE_BEGIN @@ -738,8 +739,10 @@ struct TilePrefixCallbackOp // Update our status with our tile-aggregate if (threadIdx.x == 0) { - temp_storage.block_aggregate = block_aggregate; - tile_status.SetPartial(tile_idx, block_aggregate); + detail::uninitialized_copy(&temp_storage.block_aggregate, + block_aggregate); + + tile_status.SetPartial(tile_idx, block_aggregate); } int predecessor_idx = tile_idx - threadIdx.x - 1; @@ -768,8 +771,11 @@ struct TilePrefixCallbackOp inclusive_prefix = scan_op(exclusive_prefix, block_aggregate); tile_status.SetInclusive(tile_idx, inclusive_prefix); - temp_storage.exclusive_prefix = exclusive_prefix; - temp_storage.inclusive_prefix = inclusive_prefix; + detail::uninitialized_copy(&temp_storage.exclusive_prefix, + exclusive_prefix); + + detail::uninitialized_copy(&temp_storage.inclusive_prefix, + inclusive_prefix); } // Return exclusive_prefix diff --git a/cub/block/block_exchange.cuh b/cub/block/block_exchange.cuh index 98ebfe7e82..dbbba5417e 100644 --- a/cub/block/block_exchange.cuh +++ b/cub/block/block_exchange.cuh @@ -33,10 +33,11 @@ #pragma once -#include "../config.cuh" -#include "../util_ptx.cuh" -#include "../util_type.cuh" -#include "../warp/warp_exchange.cuh" +#include +#include +#include +#include +#include CUB_NAMESPACE_BEGIN @@ -209,7 +210,8 @@ private: { int item_offset = (linear_tid * ITEMS_PER_THREAD) + ITEM; if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS; - temp_storage.buff[item_offset] = input_items[ITEM]; + detail::uninitialized_copy(temp_storage.buff + item_offset, + input_items[ITEM]); } CTA_SYNC(); @@ -250,7 +252,8 @@ private: { int item_offset = (lane_id * ITEMS_PER_THREAD) + ITEM; 
if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS; - temp_storage.buff[item_offset] = input_items[ITEM]; + detail::uninitialized_copy(temp_storage.buff + item_offset, + input_items[ITEM]); } } @@ -298,7 +301,8 @@ private: { int item_offset = warp_offset + ITEM + (lane_id * ITEMS_PER_THREAD); if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS; - temp_storage.buff[item_offset] = input_items[ITEM]; + detail::uninitialized_copy(temp_storage.buff + item_offset, + input_items[ITEM]); } WARP_SYNC(0xffffffff); @@ -328,7 +332,8 @@ private: { int item_offset = ITEM + (lane_id * ITEMS_PER_THREAD); if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS; - temp_storage.buff[item_offset] = input_items[ITEM]; + detail::uninitialized_copy(temp_storage.buff + item_offset, + input_items[ITEM]); } WARP_SYNC(0xffffffff); @@ -354,7 +359,8 @@ private: { int item_offset = ITEM + (lane_id * ITEMS_PER_THREAD); if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS; - temp_storage.buff[item_offset] = input_items[ITEM]; + detail::uninitialized_copy(temp_storage.buff + item_offset, + input_items[ITEM]); } WARP_SYNC(0xffffffff); @@ -385,7 +391,8 @@ private: { int item_offset = int(ITEM * BLOCK_THREADS) + linear_tid; if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS; - temp_storage.buff[item_offset] = input_items[ITEM]; + detail::uninitialized_copy(temp_storage.buff + item_offset, + input_items[ITEM]); } CTA_SYNC(); @@ -434,7 +441,9 @@ private: if ((item_offset >= 0) && (item_offset < TIME_SLICED_ITEMS)) { if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS; - temp_storage.buff[item_offset] = input_items[ITEM]; + detail::uninitialized_copy(temp_storage.buff + + item_offset, + input_items[ITEM]); } } } @@ -476,7 +485,8 @@ private: { int item_offset = warp_offset + (ITEM * WARP_TIME_SLICED_THREADS) + lane_id; if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS; - new (&temp_storage.buff[item_offset]) InputT (input_items[ITEM]); + detail::uninitialized_copy(temp_storage.buff + item_offset, + input_items[ITEM]); } WARP_SYNC(0xffffffff); @@ -486,7 +496,8 @@ private: { int item_offset = warp_offset + ITEM + (lane_id * ITEMS_PER_THREAD); if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS; - new(&output_items[ITEM]) OutputT(temp_storage.buff[item_offset]); + detail::uninitialized_copy(output_items + ITEM, + temp_storage.buff[item_offset]); } } @@ -512,7 +523,8 @@ private: { int item_offset = (ITEM * WARP_TIME_SLICED_THREADS) + lane_id; if (INSERT_PADDING) item_offset += item_offset >> LOG_SMEM_BANKS; - temp_storage.buff[item_offset] = input_items[ITEM]; + detail::uninitialized_copy(temp_storage.buff + item_offset, + input_items[ITEM]); } WARP_SYNC(0xffffffff); @@ -544,7 +556,8 @@ private: { int item_offset = ranks[ITEM]; if (INSERT_PADDING) item_offset = SHR_ADD(item_offset, LOG_SMEM_BANKS, item_offset); - temp_storage.buff[item_offset] = input_items[ITEM]; + detail::uninitialized_copy(temp_storage.buff + item_offset, + input_items[ITEM]); } CTA_SYNC(); @@ -584,7 +597,8 @@ private: if ((item_offset >= 0) && (item_offset < WARP_TIME_SLICED_ITEMS)) { if (INSERT_PADDING) item_offset = SHR_ADD(item_offset, LOG_SMEM_BANKS, item_offset); - temp_storage.buff[item_offset] = input_items[ITEM]; + detail::uninitialized_copy(temp_storage.buff + item_offset, + input_items[ITEM]); } } @@ -626,7 +640,8 @@ private: { int item_offset = ranks[ITEM]; if (INSERT_PADDING) item_offset = SHR_ADD(item_offset, LOG_SMEM_BANKS, item_offset); - 
temp_storage.buff[item_offset] = input_items[ITEM]; + detail::uninitialized_copy(temp_storage.buff + item_offset, + input_items[ITEM]); } CTA_SYNC(); @@ -668,7 +683,8 @@ private: if ((item_offset >= 0) && (item_offset < WARP_TIME_SLICED_ITEMS)) { if (INSERT_PADDING) item_offset = SHR_ADD(item_offset, LOG_SMEM_BANKS, item_offset); - temp_storage.buff[item_offset] = input_items[ITEM]; + detail::uninitialized_copy(temp_storage.buff + item_offset, + input_items[ITEM]); } } diff --git a/cub/block/specializations/block_reduce_warp_reductions.cuh b/cub/block/specializations/block_reduce_warp_reductions.cuh index 4fec6cad1b..1a1b6a8f51 100644 --- a/cub/block/specializations/block_reduce_warp_reductions.cuh +++ b/cub/block/specializations/block_reduce_warp_reductions.cuh @@ -33,9 +33,10 @@ #pragma once -#include "../../warp/warp_reduce.cuh" -#include "../../config.cuh" -#include "../../util_ptx.cuh" +#include +#include +#include +#include CUB_NAMESPACE_BEGIN @@ -143,7 +144,8 @@ struct BlockReduceWarpReductions // Share lane aggregates if (lane_id == 0) { - new (temp_storage.warp_aggregates + warp_id) T(warp_aggregate); + detail::uninitialized_copy(temp_storage.warp_aggregates + warp_id, + warp_aggregate); } CTA_SYNC(); diff --git a/cub/block/specializations/block_scan_warp_scans.cuh b/cub/block/specializations/block_scan_warp_scans.cuh index ed550162e2..f76131a785 100644 --- a/cub/block/specializations/block_scan_warp_scans.cuh +++ b/cub/block/specializations/block_scan_warp_scans.cuh @@ -33,9 +33,10 @@ #pragma once -#include "../../config.cuh" -#include "../../util_ptx.cuh" -#include "../../warp/warp_scan.cuh" +#include +#include +#include +#include CUB_NAMESPACE_BEGIN @@ -151,7 +152,10 @@ struct BlockScanWarpScans { // Last lane in each warp shares its warp-aggregate if (lane_id == WARP_THREADS - 1) - temp_storage.warp_aggregates[warp_id] = warp_aggregate; + { + detail::uninitialized_copy(temp_storage.warp_aggregates + warp_id, + warp_aggregate); + } CTA_SYNC(); @@ -293,9 +297,11 @@ struct BlockScanWarpScans T block_prefix = block_prefix_callback_op(block_aggregate); if (lane_id == 0) { - // Share the prefix with all threads - temp_storage.block_prefix = block_prefix; - exclusive_output = block_prefix; // The block prefix is the exclusive output for tid0 + // Share the prefix with all threads + detail::uninitialized_copy(&temp_storage.block_prefix, + block_prefix); + + exclusive_output = block_prefix; // The block prefix is the exclusive output for tid0 } } @@ -367,7 +373,8 @@ struct BlockScanWarpScans if (lane_id == 0) { // Share the prefix with all threads - temp_storage.block_prefix = block_prefix; + detail::uninitialized_copy(&temp_storage.block_prefix, + block_prefix); } } diff --git a/cub/detail/uninitialized_copy.cuh b/cub/detail/uninitialized_copy.cuh new file mode 100644 index 0000000000..f652f9bdb0 --- /dev/null +++ b/cub/detail/uninitialized_copy.cuh @@ -0,0 +1,66 @@ +/****************************************************************************** + * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include + +#include + +CUB_NAMESPACE_BEGIN + + +namespace detail +{ + +template ::value, + int + >::type = 0> +__host__ __device__ void uninitialized_copy(T *ptr, U &&val) +{ + *ptr = cuda::std::forward(val); +} + +template ::value, + int + >::type = 0> +__host__ __device__ void uninitialized_copy(T *ptr, U &&val) +{ + new (ptr) T(cuda::std::forward(val)); +} + +} // namespace detail + + +CUB_NAMESPACE_END + diff --git a/cub/device/device_scan.cuh b/cub/device/device_scan.cuh index eb099a6a44..20cb8ba872 100644 --- a/cub/device/device_scan.cuh +++ b/cub/device/device_scan.cuh @@ -193,18 +193,15 @@ struct DeviceScan { // Signed integer type for global offsets using OffsetT = int; - - // The output value type -- used as the intermediate accumulator - // Use the input value type per https://wg21.link/P0571 - using OutputT = cub::detail::value_t; + using InitT = cub::detail::value_t; // Initial value - OutputT init_value = 0; + InitT init_value{}; return DispatchScan< - InputIteratorT, OutputIteratorT, Sum, detail::InputValue, + InputIteratorT, OutputIteratorT, Sum, detail::InputValue, OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_in, d_out, - Sum(), detail::InputValue(init_value), + Sum(), detail::InputValue(init_value), num_items, stream); } @@ -332,12 +329,11 @@ struct DeviceScan { CUB_DETAIL_RUNTIME_DEBUG_SYNC_USAGE_LOG - return ExclusiveSum(d_temp_storage, - temp_storage_bytes, - d_data, - d_data, - num_items, - stream); + return ExclusiveSum(d_temp_storage, + temp_storage_bytes, + d_data, + num_items, + stream); } /** @@ -1211,12 +1207,11 @@ struct DeviceScan { CUB_DETAIL_RUNTIME_DEBUG_SYNC_USAGE_LOG - return InclusiveSum(d_temp_storage, - temp_storage_bytes, - d_data, - d_data, - num_items, - stream); + return InclusiveSum(d_temp_storage, + temp_storage_bytes, + d_data, + num_items, + stream); } /** @@ -1612,24 +1607,21 @@ struct DeviceScan ValuesOutputIteratorT d_values_out, int num_items, EqualityOpT equality_op = EqualityOpT(), - cudaStream_t stream = 0) + cudaStream_t stream = 0) { // Signed integer type for global offsets using OffsetT = int; - - // The output value type -- used as the intermediate accumulator - // Use the input value type per https://wg21.link/P0571 - using OutputT = 
cub::detail::value_t; + using InitT = cub::detail::value_t; // Initial value - OutputT init_value = 0; + InitT init_value{}; return DispatchScanByKey::Dispatch(d_temp_storage, temp_storage_bytes, d_keys_in, @@ -1833,7 +1825,7 @@ struct DeviceScan InitValueT init_value, int num_items, EqualityOpT equality_op = EqualityOpT(), - cudaStream_t stream = 0) + cudaStream_t stream = 0) { // Signed integer type for global offsets using OffsetT = int ; @@ -2007,7 +1999,7 @@ struct DeviceScan ValuesOutputIteratorT d_values_out, int num_items, EqualityOpT equality_op = EqualityOpT(), - cudaStream_t stream = 0) + cudaStream_t stream = 0) { // Signed integer type for global offsets using OffsetT = int ; @@ -2206,7 +2198,7 @@ struct DeviceScan ScanOpT scan_op, int num_items, EqualityOpT equality_op = EqualityOpT(), - cudaStream_t stream = 0) + cudaStream_t stream = 0) { // Signed integer type for global offsets using OffsetT = int; diff --git a/cub/device/dispatch/dispatch_radix_sort.cuh b/cub/device/dispatch/dispatch_radix_sort.cuh index c070fdd3d1..c51b884698 100644 --- a/cub/device/dispatch/dispatch_radix_sort.cuh +++ b/cub/device/dispatch/dispatch_radix_sort.cuh @@ -146,6 +146,7 @@ __global__ void RadixSortScanBinsKernel( OffsetT*, cub::Sum, OffsetT, + OffsetT, OffsetT> AgentScanT; diff --git a/cub/device/dispatch/dispatch_reduce.cuh b/cub/device/dispatch/dispatch_reduce.cuh index 0a313d48a3..ac434eb862 100644 --- a/cub/device/dispatch/dispatch_reduce.cuh +++ b/cub/device/dispatch/dispatch_reduce.cuh @@ -42,6 +42,7 @@ #include #include #include +#include #include #include #include @@ -124,7 +125,7 @@ __global__ void DeviceReduceKernel(InputIteratorT d_in, // Output result if (threadIdx.x == 0) { - new (d_out + blockIdx.x) AccumT(block_aggregate); + detail::uninitialized_copy(d_out + blockIdx.x, block_aggregate); } } diff --git a/cub/device/dispatch/dispatch_reduce_by_key.cuh b/cub/device/dispatch/dispatch_reduce_by_key.cuh index b6f6b7c972..738eef63da 100644 --- a/cub/device/dispatch/dispatch_reduce_by_key.cuh +++ b/cub/device/dispatch/dispatch_reduce_by_key.cuh @@ -1,6 +1,6 @@ /****************************************************************************** * Copyright (c) 2011, Duane Merrill. All rights reserved. - * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: @@ -13,10 +13,10 @@ * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND @@ -27,8 +27,8 @@ ******************************************************************************/ /** - * \file - * cub::DeviceReduceByKey provides device-wide, parallel operations for reducing segments of values residing within device-accessible memory. + * @file cub::DeviceReduceByKey provides device-wide, parallel operations for + * reducing segments of values residing within device-accessible memory. */ #pragma once @@ -56,448 +56,688 @@ CUB_NAMESPACE_BEGIN *****************************************************************************/ /** - * Multi-block reduce-by-key sweep kernel entry point + * @brief Multi-block reduce-by-key sweep kernel entry point + * + * @tparam AgentReduceByKeyPolicyT + * Parameterized AgentReduceByKeyPolicyT tuning policy type + * + * @tparam KeysInputIteratorT + * Random-access input iterator type for keys + * + * @tparam UniqueOutputIteratorT + * Random-access output iterator type for keys + * + * @tparam ValuesInputIteratorT + * Random-access input iterator type for values + * + * @tparam AggregatesOutputIteratorT + * Random-access output iterator type for values + * + * @tparam NumRunsOutputIteratorT + * Output iterator type for recording number of segments encountered + * + * @tparam ScanTileStateT + * Tile status interface type + * + * @tparam EqualityOpT + * KeyT equality operator type + * + * @tparam ReductionOpT + * ValueT reduction operator type + * + * @tparam OffsetT + * Signed integer type for global offsets + * + * @param d_keys_in + * Pointer to the input sequence of keys + * + * @param d_unique_out + * Pointer to the output sequence of unique keys (one key per run) + * + * @param d_values_in + * Pointer to the input sequence of corresponding values + * + * @param d_aggregates_out + * Pointer to the output sequence of value aggregates (one aggregate per run) + * + * @param d_num_runs_out + * Pointer to total number of runs encountered + * (i.e., the length of d_unique_out) + * + * @param tile_state + * Tile status interface + * + * @param start_tile + * The starting tile for the current grid + * + * @param equality_op + * KeyT equality operator + * + * @param reduction_op + * ValueT reduction operator + * + * @param num_items + * Total number of items to select from */ -template < - typename AgentReduceByKeyPolicyT, ///< Parameterized AgentReduceByKeyPolicyT tuning policy type - typename KeysInputIteratorT, ///< Random-access input iterator type for keys - typename UniqueOutputIteratorT, ///< Random-access output iterator type for keys - typename ValuesInputIteratorT, ///< Random-access input iterator type for values - typename AggregatesOutputIteratorT, ///< Random-access output iterator type for values - typename NumRunsOutputIteratorT, ///< Output iterator type for recording number of segments encountered - typename ScanTileStateT, ///< Tile status interface type - typename EqualityOpT, ///< KeyT equality operator type - typename ReductionOpT, ///< ValueT reduction operator type - typename OffsetT> ///< Signed integer type for global offsets -__launch_bounds__ (int(AgentReduceByKeyPolicyT::BLOCK_THREADS)) -__global__ void DeviceReduceByKeyKernel( - KeysInputIteratorT d_keys_in, ///< Pointer to the input sequence of keys - UniqueOutputIteratorT d_unique_out, 
///< Pointer to the output sequence of unique keys (one key per run) - ValuesInputIteratorT d_values_in, ///< Pointer to the input sequence of corresponding values - AggregatesOutputIteratorT d_aggregates_out, ///< Pointer to the output sequence of value aggregates (one aggregate per run) - NumRunsOutputIteratorT d_num_runs_out, ///< Pointer to total number of runs encountered (i.e., the length of d_unique_out) - ScanTileStateT tile_state, ///< Tile status interface - int start_tile, ///< The starting tile for the current grid - EqualityOpT equality_op, ///< KeyT equality operator - ReductionOpT reduction_op, ///< ValueT reduction operator - OffsetT num_items) ///< Total number of items to select from +template +__launch_bounds__(int(AgentReduceByKeyPolicyT::BLOCK_THREADS)) __global__ + void DeviceReduceByKeyKernel(KeysInputIteratorT d_keys_in, + UniqueOutputIteratorT d_unique_out, + ValuesInputIteratorT d_values_in, + AggregatesOutputIteratorT d_aggregates_out, + NumRunsOutputIteratorT d_num_runs_out, + ScanTileStateT tile_state, + int start_tile, + EqualityOpT equality_op, + ReductionOpT reduction_op, + OffsetT num_items) { - // Thread block type for reducing tiles of value segments - typedef AgentReduceByKey< - AgentReduceByKeyPolicyT, - KeysInputIteratorT, - UniqueOutputIteratorT, - ValuesInputIteratorT, - AggregatesOutputIteratorT, - NumRunsOutputIteratorT, - EqualityOpT, - ReductionOpT, - OffsetT> - AgentReduceByKeyT; - - // Shared memory for AgentReduceByKey - __shared__ typename AgentReduceByKeyT::TempStorage temp_storage; - - // Process tiles - AgentReduceByKeyT(temp_storage, d_keys_in, d_unique_out, d_values_in, d_aggregates_out, d_num_runs_out, equality_op, reduction_op).ConsumeRange( - num_items, - tile_state, - start_tile); + // Thread block type for reducing tiles of value segments + using AgentReduceByKeyT = AgentReduceByKey; + + // Shared memory for AgentReduceByKey + __shared__ typename AgentReduceByKeyT::TempStorage temp_storage; + + // Process tiles + AgentReduceByKeyT(temp_storage, + d_keys_in, + d_unique_out, + d_values_in, + d_aggregates_out, + d_num_runs_out, + equality_op, + reduction_op) + .ConsumeRange(num_items, tile_state, start_tile); } - - - /****************************************************************************** * Dispatch ******************************************************************************/ /** - * Utility class for dispatching the appropriately-tuned kernels for DeviceReduceByKey + * @brief Utility class for dispatching the appropriately-tuned kernels for + * DeviceReduceByKey + * + * @tparam KeysInputIteratorT + * Random-access input iterator type for keys + * + * @tparam UniqueOutputIteratorT + * Random-access output iterator type for keys + * + * @tparam ValuesInputIteratorT + * Random-access input iterator type for values + * + * @tparam AggregatesOutputIteratorT + * Random-access output iterator type for values + * + * @tparam NumRunsOutputIteratorT + * Output iterator type for recording number of segments encountered + * + * @tparam EqualityOpT + * KeyT equality operator type + * + * @tparam ReductionOpT + * ValueT reduction operator type + * + * @tparam OffsetT + * Signed integer type for global offsets + * */ -template < - typename KeysInputIteratorT, ///< Random-access input iterator type for keys - typename UniqueOutputIteratorT, ///< Random-access output iterator type for keys - typename ValuesInputIteratorT, ///< Random-access input iterator type for values - typename AggregatesOutputIteratorT, ///< Random-access output iterator 
type for values - typename NumRunsOutputIteratorT, ///< Output iterator type for recording number of segments encountered - typename EqualityOpT, ///< KeyT equality operator type - typename ReductionOpT, ///< ValueT reduction operator type - typename OffsetT> ///< Signed integer type for global offsets +template , + cub::detail::value_t>> struct DispatchReduceByKey { - //------------------------------------------------------------------------- - // Types and constants - //------------------------------------------------------------------------- - - // The input keys type - using KeyInputT = cub::detail::value_t; - - // The output keys type - using KeyOutputT = - cub::detail::non_void_value_t; - - // The input values type - using ValueInputT = cub::detail::value_t; - - // The output values type - using ValueOutputT = - cub::detail::non_void_value_t; - - enum + //------------------------------------------------------------------------- + // Types and constants + //------------------------------------------------------------------------- + + // The input keys type + using KeyInputT = cub::detail::value_t; + + // The output keys type + using KeyOutputT = + cub::detail::non_void_value_t; + + // The input values type + using ValueInputT = cub::detail::value_t; + + static constexpr int INIT_KERNEL_THREADS = 128; + + static constexpr int MAX_INPUT_BYTES = CUB_MAX(sizeof(KeyOutputT), + sizeof(AccumT)); + + static constexpr int COMBINED_INPUT_BYTES = sizeof(KeyOutputT) + + sizeof(AccumT); + + // Tile status descriptor interface type + using ScanTileStateT = ReduceByKeyScanTileState; + + //------------------------------------------------------------------------- + // Tuning policies + //------------------------------------------------------------------------- + + /// SM35 + struct Policy350 + { + static constexpr int NOMINAL_4B_ITEMS_PER_THREAD = 6; + static constexpr int ITEMS_PER_THREAD = + (MAX_INPUT_BYTES <= 8) + ? 
6 + : CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD, + CUB_MAX(1, + ((NOMINAL_4B_ITEMS_PER_THREAD * 8) + + COMBINED_INPUT_BYTES - 1) / + COMBINED_INPUT_BYTES)); + + using ReduceByKeyPolicyT = AgentReduceByKeyPolicy<128, + ITEMS_PER_THREAD, + BLOCK_LOAD_DIRECT, + LOAD_LDG, + BLOCK_SCAN_WARP_SCANS>; + }; + + /****************************************************************************** + * Tuning policies of current PTX compiler pass + ******************************************************************************/ + + using PtxPolicy = Policy350; + + // "Opaque" policies (whose parameterizations aren't reflected in the type + // signature) + struct PtxReduceByKeyPolicy : PtxPolicy::ReduceByKeyPolicyT + {}; + + /****************************************************************************** + * Utilities + ******************************************************************************/ + + /** + * Initialize kernel dispatch configurations with the policies corresponding + * to the PTX assembly we will use + */ + template + CUB_RUNTIME_FUNCTION __forceinline__ static void + InitConfigs(int /*ptx_version*/, KernelConfig &reduce_by_key_config) + { + NV_IF_TARGET(NV_IS_DEVICE, + ( + // We're on the device, so initialize the kernel dispatch + // configurations with the current PTX policy + reduce_by_key_config.template Init();), + ( + // We're on the host, so lookup and initialize the kernel + // dispatch configurations with the policies that match the + // device's PTX version + + // (There's only one policy right now) + reduce_by_key_config + .template Init();)); + } + + /** + * Kernel kernel dispatch configuration. + */ + struct KernelConfig + { + int block_threads; + int items_per_thread; + int tile_items; + + template + CUB_RUNTIME_FUNCTION __forceinline__ void Init() { - INIT_KERNEL_THREADS = 128, - MAX_INPUT_BYTES = CUB_MAX(sizeof(KeyOutputT), sizeof(ValueOutputT)), - COMBINED_INPUT_BYTES = sizeof(KeyOutputT) + sizeof(ValueOutputT), - }; - - // Tile status descriptor interface type - using ScanTileStateT = ReduceByKeyScanTileState; - - //------------------------------------------------------------------------- - // Tuning policies - //------------------------------------------------------------------------- - - /// SM35 - struct Policy350 - { - enum { - NOMINAL_4B_ITEMS_PER_THREAD = 6, - ITEMS_PER_THREAD = (MAX_INPUT_BYTES <= 8) ? 
6 : CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD, CUB_MAX(1, ((NOMINAL_4B_ITEMS_PER_THREAD * 8) + COMBINED_INPUT_BYTES - 1) / COMBINED_INPUT_BYTES)), - }; - - typedef AgentReduceByKeyPolicy< - 128, - ITEMS_PER_THREAD, - BLOCK_LOAD_DIRECT, - LOAD_LDG, - BLOCK_SCAN_WARP_SCANS> - ReduceByKeyPolicyT; - }; - - /****************************************************************************** - * Tuning policies of current PTX compiler pass - ******************************************************************************/ - - typedef Policy350 PtxPolicy; - - // "Opaque" policies (whose parameterizations aren't reflected in the type signature) - struct PtxReduceByKeyPolicy : PtxPolicy::ReduceByKeyPolicyT {}; - - - /****************************************************************************** - * Utilities - ******************************************************************************/ - - /** - * Initialize kernel dispatch configurations with the policies corresponding to the PTX assembly we will use - */ - template - CUB_RUNTIME_FUNCTION __forceinline__ - static void InitConfigs( - int /*ptx_version*/, - KernelConfig &reduce_by_key_config) - { - NV_IF_TARGET(NV_IS_DEVICE, - ( - // We're on the device, so initialize the kernel dispatch configurations with the current PTX policy - reduce_by_key_config.template Init(); - ), ( - // We're on the host, so lookup and initialize the kernel dispatch configurations with the policies that match the device's PTX version - - // (There's only one policy right now) - reduce_by_key_config.template Init(); - )); + block_threads = PolicyT::BLOCK_THREADS; + items_per_thread = PolicyT::ITEMS_PER_THREAD; + tile_items = block_threads * items_per_thread; } - - - /** - * Kernel kernel dispatch configuration. - */ - struct KernelConfig + }; + + //--------------------------------------------------------------------- + // Dispatch entrypoints + //--------------------------------------------------------------------- + + /** + * @brief Internal dispatch routine for computing a device-wide + * reduce-by-key using the specified kernel functions. + * + * @tparam ScanInitKernelT + * Function type of cub::DeviceScanInitKernel + * + * @tparam ReduceByKeyKernelT + * Function type of cub::DeviceReduceByKeyKernelT + * + * @param[in] d_temp_storage + * Device-accessible allocation of temporary storage. When `nullptr`, the + * required allocation size is written to `temp_storage_bytes` and no + * work is done. + * + * @param[in,out] temp_storage_bytes + * Reference to size in bytes of `d_temp_storage` allocation + * + * @param[in] d_keys_in + * Pointer to the input sequence of keys + * + * @param[out] d_unique_out + * Pointer to the output sequence of unique keys (one key per run) + * + * @param[in] d_values_in + * Pointer to the input sequence of corresponding values + * + * @param[out] d_aggregates_out + * Pointer to the output sequence of value aggregates + * (one aggregate per run) + * + * @param[out] d_num_runs_out + * Pointer to total number of runs encountered + * (i.e., the length of d_unique_out) + * + * @param[in] equality_op + * KeyT equality operator + * + * @param[in] reduction_op + * ValueT reduction operator + * + * @param[in] num_items + * Total number of items to select from + * + * @param[in] stream + * CUDA stream to launch kernels within. Default is stream0. 
+ * + * @param[in] ptx_version + * PTX version of dispatch kernels + * + * @param[in] init_kernel + * Kernel function pointer to parameterization of + * cub::DeviceScanInitKernel + * + * @param[in] reduce_by_key_kernel + * Kernel function pointer to parameterization of + * cub::DeviceReduceByKeyKernel + * + * @param[in] reduce_by_key_config + * Dispatch parameters that match the policy that + * `reduce_by_key_kernel` was compiled for + */ + template + CUB_RUNTIME_FUNCTION __forceinline__ static cudaError_t + Dispatch(void *d_temp_storage, + size_t &temp_storage_bytes, + KeysInputIteratorT d_keys_in, + UniqueOutputIteratorT d_unique_out, + ValuesInputIteratorT d_values_in, + AggregatesOutputIteratorT d_aggregates_out, + NumRunsOutputIteratorT d_num_runs_out, + EqualityOpT equality_op, + ReductionOpT reduction_op, + OffsetT num_items, + cudaStream_t stream, + int /*ptx_version*/, + ScanInitKernelT init_kernel, + ReduceByKeyKernelT reduce_by_key_kernel, + KernelConfig reduce_by_key_config) + { + cudaError error = cudaSuccess; + do { - int block_threads; - int items_per_thread; - int tile_items; + // Get device ordinal + int device_ordinal; + if (CubDebug(error = cudaGetDevice(&device_ordinal))) + { + break; + } + + // Number of input tiles + int tile_size = reduce_by_key_config.block_threads * + reduce_by_key_config.items_per_thread; + int num_tiles = + static_cast(cub::DivideAndRoundUp(num_items, tile_size)); + + // Specify temporary storage allocation requirements + size_t allocation_sizes[1]; + if (CubDebug(error = ScanTileStateT::AllocationSize(num_tiles, + allocation_sizes[0]))) + { + break; // bytes needed for tile status descriptors + } + + // Compute allocation pointers into the single storage blob (or compute + // the necessary size of the blob) + void *allocations[1] = {}; + if (CubDebug(error = AliasTemporaries(d_temp_storage, + temp_storage_bytes, + allocations, + allocation_sizes))) + { + break; + } + + if (d_temp_storage == NULL) + { + // Return if the caller is simply requesting the size of the storage + // allocation + break; + } + + // Construct the tile status interface + ScanTileStateT tile_state; + if (CubDebug(error = tile_state.Init(num_tiles, + allocations[0], + allocation_sizes[0]))) + { + break; + } + + // Log init_kernel configuration + int init_grid_size = + CUB_MAX(1, cub::DivideAndRoundUp(num_tiles, INIT_KERNEL_THREADS)); + + #ifdef CUB_DETAIL_DEBUG_ENABLE_LOG + _CubLog("Invoking init_kernel<<<%d, %d, 0, %lld>>>()\n", + init_grid_size, + INIT_KERNEL_THREADS, + (long long)stream); + #endif + + // Invoke init_kernel to initialize tile descriptors + THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron( + init_grid_size, + INIT_KERNEL_THREADS, + 0, + stream) + .doit(init_kernel, tile_state, num_tiles, d_num_runs_out); + + // Check for failure to launch + if (CubDebug(error = cudaPeekAtLastError())) + { + break; + } + + // Sync the stream if specified to flush runtime errors + error = detail::DebugSyncStream(stream); + if (CubDebug(error)) + { + break; + } + + // Return if empty problem + if (num_items == 0) + { + break; + } + + // Get SM occupancy for reduce_by_key_kernel + int reduce_by_key_sm_occupancy; + if (CubDebug(error = MaxSmOccupancy(reduce_by_key_sm_occupancy, + reduce_by_key_kernel, + reduce_by_key_config.block_threads))) + { + break; + } + + // Get max x-dimension of grid + int max_dim_x; + if (CubDebug(error = cudaDeviceGetAttribute(&max_dim_x, + cudaDevAttrMaxGridDimX, + device_ordinal))) + { + break; + } + + // Run grids in epochs (in case number of 
tiles exceeds max x-dimension + int scan_grid_size = CUB_MIN(num_tiles, max_dim_x); + for (int start_tile = 0; start_tile < num_tiles; + start_tile += scan_grid_size) + { + // Log reduce_by_key_kernel configuration + #ifdef CUB_DETAIL_DEBUG_ENABLE_LOG + _CubLog("Invoking %d reduce_by_key_kernel<<<%d, %d, 0, %lld>>>(), %d " + "items per thread, %d SM occupancy\n", + start_tile, + scan_grid_size, + reduce_by_key_config.block_threads, + (long long)stream, + reduce_by_key_config.items_per_thread, + reduce_by_key_sm_occupancy); + #endif + + // Invoke reduce_by_key_kernel + THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron( + scan_grid_size, + reduce_by_key_config.block_threads, + 0, + stream) + .doit(reduce_by_key_kernel, + d_keys_in, + d_unique_out, + d_values_in, + d_aggregates_out, + d_num_runs_out, + tile_state, + start_tile, + equality_op, + reduction_op, + num_items); - template - CUB_RUNTIME_FUNCTION __forceinline__ - void Init() + // Check for failure to launch + if (CubDebug(error = cudaPeekAtLastError())) { - block_threads = PolicyT::BLOCK_THREADS; - items_per_thread = PolicyT::ITEMS_PER_THREAD; - tile_items = block_threads * items_per_thread; + break; } - }; - - - //--------------------------------------------------------------------- - // Dispatch entrypoints - //--------------------------------------------------------------------- - - /** - * Internal dispatch routine for computing a device-wide reduce-by-key using the - * specified kernel functions. - */ - template < - typename ScanInitKernelT, ///< Function type of cub::DeviceScanInitKernel - typename ReduceByKeyKernelT> ///< Function type of cub::DeviceReduceByKeyKernelT - CUB_RUNTIME_FUNCTION __forceinline__ - static cudaError_t Dispatch( - void* d_temp_storage, ///< [in] Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. - size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation - KeysInputIteratorT d_keys_in, ///< [in] Pointer to the input sequence of keys - UniqueOutputIteratorT d_unique_out, ///< [out] Pointer to the output sequence of unique keys (one key per run) - ValuesInputIteratorT d_values_in, ///< [in] Pointer to the input sequence of corresponding values - AggregatesOutputIteratorT d_aggregates_out, ///< [out] Pointer to the output sequence of value aggregates (one aggregate per run) - NumRunsOutputIteratorT d_num_runs_out, ///< [out] Pointer to total number of runs encountered (i.e., the length of d_unique_out) - EqualityOpT equality_op, ///< [in] KeyT equality operator - ReductionOpT reduction_op, ///< [in] ValueT reduction operator - OffsetT num_items, ///< [in] Total number of items to select from - cudaStream_t stream, ///< [in] CUDA stream to launch kernels within. Default is stream0. 
- int /*ptx_version*/, ///< [in] PTX version of dispatch kernels - ScanInitKernelT init_kernel, ///< [in] Kernel function pointer to parameterization of cub::DeviceScanInitKernel - ReduceByKeyKernelT reduce_by_key_kernel, ///< [in] Kernel function pointer to parameterization of cub::DeviceReduceByKeyKernel - KernelConfig reduce_by_key_config) ///< [in] Dispatch parameters that match the policy that \p reduce_by_key_kernel was compiled for - { - cudaError error = cudaSuccess; - do + + // Sync the stream if specified to flush runtime errors + error = detail::DebugSyncStream(stream); + if (CubDebug(error)) { - // Get device ordinal - int device_ordinal; - if (CubDebug(error = cudaGetDevice(&device_ordinal))) break; - - // Number of input tiles - int tile_size = reduce_by_key_config.block_threads * reduce_by_key_config.items_per_thread; - int num_tiles = static_cast(cub::DivideAndRoundUp(num_items, tile_size)); - - // Specify temporary storage allocation requirements - size_t allocation_sizes[1]; - if (CubDebug(error = ScanTileStateT::AllocationSize(num_tiles, allocation_sizes[0]))) break; // bytes needed for tile status descriptors - - // Compute allocation pointers into the single storage blob (or compute the necessary size of the blob) - void* allocations[1] = {}; - if (CubDebug(error = AliasTemporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes))) break; - if (d_temp_storage == NULL) - { - // Return if the caller is simply requesting the size of the storage allocation - break; - } - - // Construct the tile status interface - ScanTileStateT tile_state; - if (CubDebug(error = tile_state.Init(num_tiles, allocations[0], allocation_sizes[0]))) break; - - // Log init_kernel configuration - int init_grid_size = CUB_MAX(1, cub::DivideAndRoundUp(num_tiles, INIT_KERNEL_THREADS)); - #ifdef CUB_DETAIL_DEBUG_ENABLE_LOG - _CubLog("Invoking init_kernel<<<%d, %d, 0, %lld>>>()\n", init_grid_size, INIT_KERNEL_THREADS, (long long) stream); - #endif - - // Invoke init_kernel to initialize tile descriptors - THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron( - init_grid_size, INIT_KERNEL_THREADS, 0, stream - ).doit(init_kernel, - tile_state, - num_tiles, - d_num_runs_out); - - // Check for failure to launch - if (CubDebug(error = cudaPeekAtLastError())) - { - break; - } - - // Sync the stream if specified to flush runtime errors - error = detail::DebugSyncStream(stream); - if (CubDebug(error)) - { - break; - } - - // Return if empty problem - if (num_items == 0) - { - break; - } - - // Get SM occupancy for reduce_by_key_kernel - int reduce_by_key_sm_occupancy; - if (CubDebug(error = MaxSmOccupancy( - reduce_by_key_sm_occupancy, // out - reduce_by_key_kernel, - reduce_by_key_config.block_threads))) break; - - // Get max x-dimension of grid - int max_dim_x; - if (CubDebug(error = cudaDeviceGetAttribute(&max_dim_x, cudaDevAttrMaxGridDimX, device_ordinal))) break;; - - // Run grids in epochs (in case number of tiles exceeds max x-dimension - int scan_grid_size = CUB_MIN(num_tiles, max_dim_x); - for (int start_tile = 0; start_tile < num_tiles; start_tile += scan_grid_size) - { - // Log reduce_by_key_kernel configuration - #ifdef CUB_DETAIL_DEBUG_ENABLE_LOG - _CubLog("Invoking %d reduce_by_key_kernel<<<%d, %d, 0, %lld>>>(), %d items per thread, %d SM occupancy\n", - start_tile, scan_grid_size, reduce_by_key_config.block_threads, (long long) stream, reduce_by_key_config.items_per_thread, reduce_by_key_sm_occupancy); - #endif - - // Invoke reduce_by_key_kernel - 
THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron( - scan_grid_size, reduce_by_key_config.block_threads, 0, - stream - ).doit(reduce_by_key_kernel, + break; + } + } + } while (0); + + return error; + } + + template + CUB_DETAIL_RUNTIME_DEBUG_SYNC_IS_NOT_SUPPORTED + CUB_RUNTIME_FUNCTION __forceinline__ static cudaError_t + Dispatch(void *d_temp_storage, + size_t &temp_storage_bytes, + KeysInputIteratorT d_keys_in, + UniqueOutputIteratorT d_unique_out, + ValuesInputIteratorT d_values_in, + AggregatesOutputIteratorT d_aggregates_out, + NumRunsOutputIteratorT d_num_runs_out, + EqualityOpT equality_op, + ReductionOpT reduction_op, + OffsetT num_items, + cudaStream_t stream, + bool debug_synchronous, + int ptx_version, + ScanInitKernelT init_kernel, + ReduceByKeyKernelT reduce_by_key_kernel, + KernelConfig reduce_by_key_config) + { + CUB_DETAIL_RUNTIME_DEBUG_SYNC_USAGE_LOG + + return Dispatch(d_temp_storage, + temp_storage_bytes, + d_keys_in, + d_unique_out, + d_values_in, + d_aggregates_out, + d_num_runs_out, + equality_op, + reduction_op, + num_items, + stream, + ptx_version, + init_kernel, + reduce_by_key_kernel, + reduce_by_key_config); + } + + /** + * Internal dispatch routine + * @param[in] d_temp_storage + * Device-accessible allocation of temporary storage. When `nullptr`, the + * required allocation size is written to `temp_storage_bytes` and no + * work is done. + * + * @param[in,out] temp_storage_bytes + * Reference to size in bytes of `d_temp_storage` allocation + * + * @param[in] d_keys_in + * Pointer to the input sequence of keys + * + * @param[out] d_unique_out + * Pointer to the output sequence of unique keys (one key per run) + * + * @param[in] d_values_in + * Pointer to the input sequence of corresponding values + * + * @param[out] d_aggregates_out + * Pointer to the output sequence of value aggregates + * (one aggregate per run) + * + * @param[out] d_num_runs_out + * Pointer to total number of runs encountered + * (i.e., the length of d_unique_out) + * + * @param[in] equality_op + * KeyT equality operator + * + * @param[in] reduction_op + * ValueT reduction operator + * + * @param[in] num_items + * Total number of items to select from + * + * @param[in] stream + * CUDA stream to launch kernels within. Default is stream0. 
+ */ + CUB_RUNTIME_FUNCTION __forceinline__ static cudaError_t + Dispatch(void *d_temp_storage, + size_t &temp_storage_bytes, + KeysInputIteratorT d_keys_in, + UniqueOutputIteratorT d_unique_out, + ValuesInputIteratorT d_values_in, + AggregatesOutputIteratorT d_aggregates_out, + NumRunsOutputIteratorT d_num_runs_out, + EqualityOpT equality_op, + ReductionOpT reduction_op, + OffsetT num_items, + cudaStream_t stream) + { + cudaError error = cudaSuccess; + + do + { + // Get PTX version + int ptx_version = 0; + if (CubDebug(error = PtxVersion(ptx_version))) + { + break; + } + + // Get kernel kernel dispatch configurations + KernelConfig reduce_by_key_config; + InitConfigs(ptx_version, reduce_by_key_config); + + // Dispatch + if (CubDebug( + error = Dispatch( + d_temp_storage, + temp_storage_bytes, + d_keys_in, + d_unique_out, + d_values_in, + d_aggregates_out, + d_num_runs_out, + equality_op, + reduction_op, + num_items, + stream, + ptx_version, + DeviceCompactInitKernel, + DeviceReduceByKeyKernel, + reduce_by_key_config))) + { + break; + } + } while (0); + + return error; + } + + CUB_DETAIL_RUNTIME_DEBUG_SYNC_IS_NOT_SUPPORTED + CUB_RUNTIME_FUNCTION __forceinline__ static cudaError_t + Dispatch(void *d_temp_storage, + size_t &temp_storage_bytes, + KeysInputIteratorT d_keys_in, + UniqueOutputIteratorT d_unique_out, + ValuesInputIteratorT d_values_in, + AggregatesOutputIteratorT d_aggregates_out, + NumRunsOutputIteratorT d_num_runs_out, + EqualityOpT equality_op, + ReductionOpT reduction_op, + OffsetT num_items, + cudaStream_t stream, + bool debug_synchronous) + { + CUB_DETAIL_RUNTIME_DEBUG_SYNC_USAGE_LOG + + return Dispatch(d_temp_storage, + temp_storage_bytes, d_keys_in, d_unique_out, d_values_in, d_aggregates_out, d_num_runs_out, - tile_state, - start_tile, equality_op, reduction_op, - num_items); - - // Check for failure to launch - if (CubDebug(error = cudaPeekAtLastError())) - { - break; - } - - // Sync the stream if specified to flush runtime errors - error = detail::DebugSyncStream(stream); - if (CubDebug(error)) - { - break; - } - } - } - while (0); - - return error; - } - - template - CUB_DETAIL_RUNTIME_DEBUG_SYNC_IS_NOT_SUPPORTED - CUB_RUNTIME_FUNCTION __forceinline__ static cudaError_t - Dispatch(void *d_temp_storage, - size_t &temp_storage_bytes, - KeysInputIteratorT d_keys_in, - UniqueOutputIteratorT d_unique_out, - ValuesInputIteratorT d_values_in, - AggregatesOutputIteratorT d_aggregates_out, - NumRunsOutputIteratorT d_num_runs_out, - EqualityOpT equality_op, - ReductionOpT reduction_op, - OffsetT num_items, - cudaStream_t stream, - bool debug_synchronous, - int ptx_version, - ScanInitKernelT init_kernel, - ReduceByKeyKernelT reduce_by_key_kernel, - KernelConfig reduce_by_key_config) - { - CUB_DETAIL_RUNTIME_DEBUG_SYNC_USAGE_LOG - - return Dispatch(d_temp_storage, - temp_storage_bytes, - d_keys_in, - d_unique_out, - d_values_in, - d_aggregates_out, - d_num_runs_out, - equality_op, - reduction_op, - num_items, - stream, - ptx_version, - init_kernel, - reduce_by_key_kernel, - reduce_by_key_config); - } - - /** - * Internal dispatch routine - */ - CUB_RUNTIME_FUNCTION __forceinline__ - static cudaError_t Dispatch( - void* d_temp_storage, ///< [in] Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. 
- size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation - KeysInputIteratorT d_keys_in, ///< [in] Pointer to the input sequence of keys - UniqueOutputIteratorT d_unique_out, ///< [out] Pointer to the output sequence of unique keys (one key per run) - ValuesInputIteratorT d_values_in, ///< [in] Pointer to the input sequence of corresponding values - AggregatesOutputIteratorT d_aggregates_out, ///< [out] Pointer to the output sequence of value aggregates (one aggregate per run) - NumRunsOutputIteratorT d_num_runs_out, ///< [out] Pointer to total number of runs encountered (i.e., the length of d_unique_out) - EqualityOpT equality_op, ///< [in] KeyT equality operator - ReductionOpT reduction_op, ///< [in] ValueT reduction operator - OffsetT num_items, ///< [in] Total number of items to select from - cudaStream_t stream) ///< [in] CUDA stream to launch kernels within. Default is stream0. - { - cudaError error = cudaSuccess; - do - { - // Get PTX version - int ptx_version = 0; - if (CubDebug(error = PtxVersion(ptx_version))) break; - - // Get kernel kernel dispatch configurations - KernelConfig reduce_by_key_config; - InitConfigs(ptx_version, reduce_by_key_config); - - // Dispatch - if (CubDebug(error = Dispatch( - d_temp_storage, - temp_storage_bytes, - d_keys_in, - d_unique_out, - d_values_in, - d_aggregates_out, - d_num_runs_out, - equality_op, - reduction_op, - num_items, - stream, - ptx_version, - DeviceCompactInitKernel, - DeviceReduceByKeyKernel, - reduce_by_key_config))) break; - } - while (0); - - return error; - } - - CUB_DETAIL_RUNTIME_DEBUG_SYNC_IS_NOT_SUPPORTED - CUB_RUNTIME_FUNCTION __forceinline__ static cudaError_t - Dispatch(void *d_temp_storage, - size_t &temp_storage_bytes, - KeysInputIteratorT d_keys_in, - UniqueOutputIteratorT d_unique_out, - ValuesInputIteratorT d_values_in, - AggregatesOutputIteratorT d_aggregates_out, - NumRunsOutputIteratorT d_num_runs_out, - EqualityOpT equality_op, - ReductionOpT reduction_op, - OffsetT num_items, - cudaStream_t stream, - bool debug_synchronous) - { - CUB_DETAIL_RUNTIME_DEBUG_SYNC_USAGE_LOG - - return Dispatch(d_temp_storage, - temp_storage_bytes, - d_keys_in, - d_unique_out, - d_values_in, - d_aggregates_out, - d_num_runs_out, - equality_op, - reduction_op, - num_items, - stream); - } + num_items, + stream); + } }; CUB_NAMESPACE_END - diff --git a/cub/device/dispatch/dispatch_scan.cuh b/cub/device/dispatch/dispatch_scan.cuh index 1d5267253a..0df89b7c45 100644 --- a/cub/device/dispatch/dispatch_scan.cuh +++ b/cub/device/dispatch/dispatch_scan.cuh @@ -1,7 +1,7 @@ /****************************************************************************** * Copyright (c) 2011, Duane Merrill. All rights reserved. - * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: @@ -14,10 +14,10 @@ * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND @@ -28,8 +28,9 @@ ******************************************************************************/ /** - * \file - * cub::DeviceScan provides device-wide, parallel operations for computing a prefix scan across a sequence of data items residing within device-accessible memory. + * @file cub::DeviceScan provides device-wide, parallel operations for + * computing a prefix scan across a sequence of data items residing + * within device-accessible memory. */ #pragma once @@ -54,453 +55,642 @@ CUB_NAMESPACE_BEGIN *****************************************************************************/ /** - * Initialization kernel for tile status initialization (multi-block) + * @brief Initialization kernel for tile status initialization (multi-block) + * + * @tparam ScanTileStateT + * Tile status interface type + * + * @param[in] tile_state + * Tile status interface + * + * @param[in] num_tiles + * Number of tiles */ -template < - typename ScanTileStateT> ///< Tile status interface type -__global__ void DeviceScanInitKernel( - ScanTileStateT tile_state, ///< [in] Tile status interface - int num_tiles) ///< [in] Number of tiles +template +__global__ void DeviceScanInitKernel(ScanTileStateT tile_state, int num_tiles) { - // Initialize tile status - tile_state.InitializeStatus(num_tiles); + // Initialize tile status + tile_state.InitializeStatus(num_tiles); } /** * Initialization kernel for tile status initialization (multi-block) + * + * @tparam ScanTileStateT + * Tile status interface type + * + * @tparam NumSelectedIteratorT + * Output iterator type for recording the number of items selected + * + * @param[in] tile_state + * Tile status interface + * + * @param[in] num_tiles + * Number of tiles + * + * @param[out] d_num_selected_out + * Pointer to the total number of items selected + * (i.e., length of `d_selected_out`) */ -template < - typename ScanTileStateT, ///< Tile status interface type - typename NumSelectedIteratorT> ///< Output iterator type for recording the number of items selected -__global__ void DeviceCompactInitKernel( - ScanTileStateT tile_state, ///< [in] Tile status interface - int num_tiles, ///< [in] Number of tiles - NumSelectedIteratorT d_num_selected_out) ///< [out] Pointer to the total number of items selected (i.e., length of \p d_selected_out) +template +__global__ void DeviceCompactInitKernel(ScanTileStateT tile_state, + int num_tiles, + NumSelectedIteratorT d_num_selected_out) { - // Initialize tile status - tile_state.InitializeStatus(num_tiles); - - // Initialize d_num_selected_out - if ((blockIdx.x == 0) && (threadIdx.x == 0)) - *d_num_selected_out = 0; + // Initialize tile status + tile_state.InitializeStatus(num_tiles); + + // Initialize d_num_selected_out + if ((blockIdx.x == 0) && (threadIdx.x == 0)) + { + *d_num_selected_out = 0; + } } - /** - * Scan kernel entry point (multi-block) + * @brief Scan kernel entry point (multi-block) + * + * + * @tparam ChainedPolicyT + * Chained 
tuning policy
+ *
+ * @tparam InputIteratorT
+ * Random-access input iterator type for reading scan inputs \iterator
+ *
+ * @tparam OutputIteratorT
+ * Random-access output iterator type for writing scan outputs \iterator
+ *
+ * @tparam ScanTileStateT
+ * Tile status interface type
+ *
+ * @tparam ScanOpT
+ * Binary scan functor type having member
+ * `auto operator()(const T &a, const U &b)`
+ *
+ * @tparam InitValueT
+ * Initial value to seed the exclusive scan
+ * (cub::NullType for inclusive scans)
+ *
+ * @tparam OffsetT
+ * Signed integer type for global offsets
+ *
+ * @param[in] d_in
+ * Input data
+ *
+ * @param[out] d_out
+ * Output data
+ *
+ * @param[in] tile_state
+ * Tile status interface
+ *
+ * @param[in] start_tile
+ * The starting tile for the current grid
+ *
+ * @param[in] scan_op
+ * Binary scan functor
+ *
+ * @param[in] init_value
+ * Initial value to seed the exclusive scan
+ *
+ * @param[in] num_items
+ * Total number of scan items for the entire problem
 */
-template <
- typename ChainedPolicyT, ///< Chained tuning policy
- typename InputIteratorT, ///< Random-access input iterator type for reading scan inputs \iterator
- typename OutputIteratorT, ///< Random-access output iterator type for writing scan outputs \iterator
- typename ScanTileStateT, ///< Tile status interface type
- typename ScanOpT, ///< Binary scan functor type having member T operator()(const T &a, const T &b)
- typename InitValueT, ///< Initial value to seed the exclusive scan (cub::NullType for inclusive scans)
- typename OffsetT> ///< Signed integer type for global offsets
-__launch_bounds__ (int(ChainedPolicyT::ActivePolicy::ScanPolicyT::BLOCK_THREADS))
-__global__ void DeviceScanKernel(
- InputIteratorT d_in, ///< Input data
- OutputIteratorT d_out, ///< Output data
- ScanTileStateT tile_state, ///< Tile status interface
- int start_tile, ///< The starting tile for the current grid
- ScanOpT scan_op, ///< Binary scan functor
- InitValueT init_value, ///< Initial value to seed the exclusive scan
- OffsetT num_items) ///< Total number of scan items for the entire problem
+template
+__launch_bounds__(int(ChainedPolicyT::ActivePolicy::ScanPolicyT::BLOCK_THREADS))
+ __global__ void DeviceScanKernel(InputIteratorT d_in,
+ OutputIteratorT d_out,
+ ScanTileStateT tile_state,
+ int start_tile,
+ ScanOpT scan_op,
+ InitValueT init_value,
+ OffsetT num_items)
 {
- using RealInitValueT = typename InitValueT::value_type;
- typedef typename ChainedPolicyT::ActivePolicy::ScanPolicyT ScanPolicyT;
-
- // Thread block type for scanning input tiles
- typedef AgentScan<
- ScanPolicyT,
- InputIteratorT,
- OutputIteratorT,
- ScanOpT,
- RealInitValueT,
- OffsetT> AgentScanT;
-
- // Shared memory for AgentScan
- __shared__ typename AgentScanT::TempStorage temp_storage;
-
- RealInitValueT real_init_value = init_value;
-
- // Process tiles
- AgentScanT(temp_storage, d_in, d_out, scan_op, real_init_value).ConsumeRange(
- num_items,
- tile_state,
- start_tile);
+ using RealInitValueT = typename InitValueT::value_type;
+ typedef typename ChainedPolicyT::ActivePolicy::ScanPolicyT ScanPolicyT;
+
+ // Thread block type for scanning input tiles
+ typedef AgentScan
+ AgentScanT;
+
+ // Shared memory for AgentScan
+ __shared__ typename AgentScanT::TempStorage temp_storage;
+
+ RealInitValueT real_init_value = init_value;
+
+ // Process tiles
+ AgentScanT(temp_storage, d_in, d_out, scan_op, real_init_value)
+ .ConsumeRange(num_items, tile_state, start_tile);
 }
-
 /******************************************************************************
 * Policy
******************************************************************************/ -template < - typename OutputT> ///< Data type +template ///< Data type struct DeviceScanPolicy { - // For large values, use timesliced loads/stores to fit shared memory. - static constexpr bool LargeValues = sizeof(OutputT) > 128; - static constexpr BlockLoadAlgorithm ScanTransposedLoad = - LargeValues ? BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED - : BLOCK_LOAD_WARP_TRANSPOSE; - static constexpr BlockStoreAlgorithm ScanTransposedStore = - LargeValues ? BLOCK_STORE_WARP_TRANSPOSE_TIMESLICED - : BLOCK_STORE_WARP_TRANSPOSE; - - /// SM350 - struct Policy350 : ChainedPolicy<350, Policy350, Policy350> - { - // GTX Titan: 29.5B items/s (232.4 GB/s) @ 48M 32-bit T - typedef AgentScanPolicy< - 128, 12, ///< Threads per block, items per thread - OutputT, - BLOCK_LOAD_DIRECT, - LOAD_CA, - BLOCK_STORE_WARP_TRANSPOSE_TIMESLICED, - BLOCK_SCAN_RAKING> - ScanPolicyT; - }; - - /// SM520 - struct Policy520 : ChainedPolicy<520, Policy520, Policy350> - { - // Titan X: 32.47B items/s @ 48M 32-bit T - typedef AgentScanPolicy< - 128, 12, ///< Threads per block, items per thread - OutputT, - BLOCK_LOAD_DIRECT, - LOAD_CA, - ScanTransposedStore, - BLOCK_SCAN_WARP_SCANS> - ScanPolicyT; - }; - - /// SM600 - struct Policy600 : ChainedPolicy<600, Policy600, Policy520> - { - typedef AgentScanPolicy< - 128, 15, ///< Threads per block, items per thread - OutputT, - ScanTransposedLoad, - LOAD_DEFAULT, - ScanTransposedStore, - BLOCK_SCAN_WARP_SCANS> - ScanPolicyT; - }; - - /// MaxPolicy - typedef Policy600 MaxPolicy; + // For large values, use timesliced loads/stores to fit shared memory. + static constexpr bool LargeValues = sizeof(AccumT) > 128; + static constexpr BlockLoadAlgorithm ScanTransposedLoad = + LargeValues ? BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED + : BLOCK_LOAD_WARP_TRANSPOSE; + static constexpr BlockStoreAlgorithm ScanTransposedStore = + LargeValues ? 
BLOCK_STORE_WARP_TRANSPOSE_TIMESLICED + : BLOCK_STORE_WARP_TRANSPOSE; + + /// SM350 + struct Policy350 : ChainedPolicy<350, Policy350, Policy350> + { + // GTX Titan: 29.5B items/s (232.4 GB/s) @ 48M 32-bit T + typedef AgentScanPolicy<128, + 12, ///< Threads per block, items per thread + AccumT, + BLOCK_LOAD_DIRECT, + LOAD_CA, + BLOCK_STORE_WARP_TRANSPOSE_TIMESLICED, + BLOCK_SCAN_RAKING> + ScanPolicyT; + }; + + /// SM520 + struct Policy520 : ChainedPolicy<520, Policy520, Policy350> + { + // Titan X: 32.47B items/s @ 48M 32-bit T + typedef AgentScanPolicy<128, + 12, ///< Threads per block, items per thread + AccumT, + BLOCK_LOAD_DIRECT, + LOAD_CA, + ScanTransposedStore, + BLOCK_SCAN_WARP_SCANS> + ScanPolicyT; + }; + + /// SM600 + struct Policy600 : ChainedPolicy<600, Policy600, Policy520> + { + typedef AgentScanPolicy<128, + 15, ///< Threads per block, items per thread + AccumT, + ScanTransposedLoad, + LOAD_DEFAULT, + ScanTransposedStore, + BLOCK_SCAN_WARP_SCANS> + ScanPolicyT; + }; + + /// MaxPolicy + typedef Policy600 MaxPolicy; }; - /****************************************************************************** * Dispatch ******************************************************************************/ - /** - * Utility class for dispatching the appropriately-tuned kernels for DeviceScan + * @brief Utility class for dispatching the appropriately-tuned kernels for + * DeviceScan + * + * @tparam InputIteratorT + * Random-access input iterator type for reading scan inputs \iterator + * + * @tparam OutputIteratorT + * Random-access output iterator type for writing scan outputs \iterator + * + * @tparam ScanOpT + * Binary scan functor type having member + * `auto operator()(const T &a, const U &b)` + * + * @tparam InitValueT + * The init_value element type for ScanOpT (cub::NullType for inclusive scans) + * + * @tparam OffsetT + * Signed integer type for global offsets + * */ -template < - typename InputIteratorT, ///< Random-access input iterator type for reading scan inputs \iterator - typename OutputIteratorT, ///< Random-access output iterator type for writing scan outputs \iterator - typename ScanOpT, ///< Binary scan functor type having member T operator()(const T &a, const T &b) - typename InitValueT, ///< The init_value element type for ScanOpT (cub::NullType for inclusive scans) - typename OffsetT, ///< Signed integer type for global offsets - typename SelectedPolicy = DeviceScanPolicy< - // Accumulator type. - cub::detail::conditional_t::value, - cub::detail::value_t, - typename InitValueT::value_type>>> -struct DispatchScan: - SelectedPolicy +template ::value, + cub::detail::value_t, + typename InitValueT::value_type>, + cub::detail::value_t>, + typename SelectedPolicy = DeviceScanPolicy> +struct DispatchScan : SelectedPolicy { - //--------------------------------------------------------------------- - // Constants and Types - //--------------------------------------------------------------------- - - enum + //--------------------------------------------------------------------- + // Constants and Types + //--------------------------------------------------------------------- + + static constexpr int INIT_KERNEL_THREADS = 128; + + // The input value type + using InputT = cub::detail::value_t; + + /// Device-accessible allocation of temporary storage. When NULL, the + /// required allocation size is written to \p temp_storage_bytes and no work + /// is done. 
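The members declared next follow CUB's standard two-phase ("NULL query") convention described in the comment above. A minimal caller-side sketch of that flow, assuming the usual public cub::DeviceScan::ExclusiveSum entry point that ultimately routes through this dispatcher; buffer names and the omitted error handling are illustrative only:

#include <cub/device/device_scan.cuh>
#include <cuda_runtime.h>

cudaError_t run_exclusive_sum(const int *d_in, int *d_out, int num_items,
                              cudaStream_t stream)
{
  void *d_temp_storage = nullptr;
  size_t temp_storage_bytes = 0;

  // Phase 1: d_temp_storage == nullptr, so only the required size is
  // written to temp_storage_bytes and no kernels are launched.
  cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes,
                                d_in, d_out, num_items, stream);

  // Phase 2: allocate the scratch space and run the scan for real.
  cudaMalloc(&d_temp_storage, temp_storage_bytes);
  cudaError_t status = cub::DeviceScan::ExclusiveSum(d_temp_storage,
                                                     temp_storage_bytes,
                                                     d_in, d_out,
                                                     num_items, stream);
  cudaFree(d_temp_storage);
  return status;
}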
+ void *d_temp_storage; + + /// Reference to size in bytes of \p d_temp_storage allocation + size_t &temp_storage_bytes; + + /// Iterator to the input sequence of data items + InputIteratorT d_in; + + /// Iterator to the output sequence of data items + OutputIteratorT d_out; + + /// Binary scan functor + ScanOpT scan_op; + + /// Initial value to seed the exclusive scan + InitValueT init_value; + + /// Total number of input items (i.e., the length of \p d_in) + OffsetT num_items; + + /// CUDA stream to launch kernels within. Default is stream0. + cudaStream_t stream; + + int ptx_version; + + /** + * + * @param[in] d_temp_storage + * Device-accessible allocation of temporary storage. When `nullptr`, the + * required allocation size is written to `temp_storage_bytes` and no + * work is done. + * + * @param[in,out] temp_storage_bytes + * Reference to size in bytes of `d_temp_storage` allocation + * + * @param[in] d_in + * Iterator to the input sequence of data items + * + * @param[out] d_out + * Iterator to the output sequence of data items + * + * @param[in] num_items + * Total number of input items (i.e., the length of `d_in`) + * + * @param[in] scan_op + * Binary scan functor + * + * @param[in] init_value + * Initial value to seed the exclusive scan + * + * @param[in] stream + * **[optional]** CUDA stream to launch kernels within. + * Default is stream0. + */ + CUB_RUNTIME_FUNCTION __forceinline__ DispatchScan(void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + OffsetT num_items, + ScanOpT scan_op, + InitValueT init_value, + cudaStream_t stream, + int ptx_version) + : d_temp_storage(d_temp_storage) + , temp_storage_bytes(temp_storage_bytes) + , d_in(d_in) + , d_out(d_out) + , scan_op(scan_op) + , init_value(init_value) + , num_items(num_items) + , stream(stream) + , ptx_version(ptx_version) + {} + + CUB_DETAIL_RUNTIME_DEBUG_SYNC_IS_NOT_SUPPORTED + CUB_RUNTIME_FUNCTION __forceinline__ DispatchScan(void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + OffsetT num_items, + ScanOpT scan_op, + InitValueT init_value, + cudaStream_t stream, + bool debug_synchronous, + int ptx_version) + : d_temp_storage(d_temp_storage) + , temp_storage_bytes(temp_storage_bytes) + , d_in(d_in) + , d_out(d_out) + , scan_op(scan_op) + , init_value(init_value) + , num_items(num_items) + , stream(stream) + , ptx_version(ptx_version) + { + CUB_DETAIL_RUNTIME_DEBUG_SYNC_USAGE_LOG + } + + template + CUB_RUNTIME_FUNCTION __host__ __forceinline__ cudaError_t + Invoke(InitKernel init_kernel, ScanKernel scan_kernel) + { + typedef typename ActivePolicyT::ScanPolicyT Policy; + typedef typename cub::ScanTileState ScanTileStateT; + + // `LOAD_LDG` makes in-place execution UB and doesn't lead to better + // performance. + static_assert(Policy::LOAD_MODIFIER != CacheLoadModifier::LOAD_LDG, + "The memory consistency model does not apply to texture " + "accesses"); + + cudaError error = cudaSuccess; + do { - INIT_KERNEL_THREADS = 128 - }; - - // The input value type - using InputT = cub::detail::value_t; - - // The output value type -- used as the intermediate accumulator - // Per https://wg21.link/P0571, use InitValueT::value_type if provided, otherwise the - // input iterator's value type. - using OutputT = - cub::detail::conditional_t::value, - InputT, - typename InitValueT::value_type>; - - void* d_temp_storage; ///< [in] Device-accessible allocation of temporary storage. 
When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. - size_t& temp_storage_bytes; ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation - InputIteratorT d_in; ///< [in] Iterator to the input sequence of data items - OutputIteratorT d_out; ///< [out] Iterator to the output sequence of data items - ScanOpT scan_op; ///< [in] Binary scan functor - InitValueT init_value; ///< [in] Initial value to seed the exclusive scan - OffsetT num_items; ///< [in] Total number of input items (i.e., the length of \p d_in) - cudaStream_t stream; ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. - int ptx_version; - - CUB_RUNTIME_FUNCTION __forceinline__ - DispatchScan( - void* d_temp_storage, ///< [in] Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. - size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation - InputIteratorT d_in, ///< [in] Iterator to the input sequence of data items - OutputIteratorT d_out, ///< [out] Iterator to the output sequence of data items - OffsetT num_items, ///< [in] Total number of input items (i.e., the length of \p d_in) - ScanOpT scan_op, ///< [in] Binary scan functor - InitValueT init_value, ///< [in] Initial value to seed the exclusive scan - cudaStream_t stream, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. - int ptx_version - ): - d_temp_storage(d_temp_storage), - temp_storage_bytes(temp_storage_bytes), - d_in(d_in), - d_out(d_out), - scan_op(scan_op), - init_value(init_value), - num_items(num_items), - stream(stream), - ptx_version(ptx_version) - {} - - CUB_DETAIL_RUNTIME_DEBUG_SYNC_IS_NOT_SUPPORTED - CUB_RUNTIME_FUNCTION __forceinline__ - DispatchScan(void *d_temp_storage, - size_t &temp_storage_bytes, - InputIteratorT d_in, - OutputIteratorT d_out, - OffsetT num_items, - ScanOpT scan_op, - InitValueT init_value, - cudaStream_t stream, - bool debug_synchronous, - int ptx_version) - : d_temp_storage(d_temp_storage) - , temp_storage_bytes(temp_storage_bytes) - , d_in(d_in) - , d_out(d_out) - , scan_op(scan_op) - , init_value(init_value) - , num_items(num_items) - , stream(stream) - , ptx_version(ptx_version) - { - CUB_DETAIL_RUNTIME_DEBUG_SYNC_USAGE_LOG - } - - template - CUB_RUNTIME_FUNCTION __host__ __forceinline__ - cudaError_t Invoke(InitKernel init_kernel, ScanKernel scan_kernel) - { - typedef typename ActivePolicyT::ScanPolicyT Policy; - typedef typename cub::ScanTileState ScanTileStateT; - - // `LOAD_LDG` makes in-place execution UB and doesn't lead to better - // performance. 
- static_assert( - Policy::LOAD_MODIFIER != CacheLoadModifier::LOAD_LDG, - "The memory consistency model does not apply to texture accesses"); + // Get device ordinal + int device_ordinal; + if (CubDebug(error = cudaGetDevice(&device_ordinal))) + { + break; + } + + // Number of input tiles + int tile_size = Policy::BLOCK_THREADS * Policy::ITEMS_PER_THREAD; + int num_tiles = + static_cast(cub::DivideAndRoundUp(num_items, tile_size)); + + // Specify temporary storage allocation requirements + size_t allocation_sizes[1]; + if (CubDebug(error = ScanTileStateT::AllocationSize(num_tiles, + allocation_sizes[0]))) + { + break; // bytes needed for tile status descriptors + } + + // Compute allocation pointers into the single storage blob (or compute + // the necessary size of the blob) + void *allocations[1] = {}; + if (CubDebug(error = AliasTemporaries(d_temp_storage, + temp_storage_bytes, + allocations, + allocation_sizes))) + { + break; + } + + if (d_temp_storage == NULL) + { + // Return if the caller is simply requesting the size of the storage + // allocation + break; + } + + // Return if empty problem + if (num_items == 0) + { + break; + } + + // Construct the tile status interface + ScanTileStateT tile_state; + if (CubDebug(error = tile_state.Init(num_tiles, + allocations[0], + allocation_sizes[0]))) + { + break; + } + + // Log init_kernel configuration + int init_grid_size = cub::DivideAndRoundUp(num_tiles, + INIT_KERNEL_THREADS); + + #ifdef CUB_DETAIL_DEBUG_ENABLE_LOG + _CubLog("Invoking init_kernel<<<%d, %d, 0, %lld>>>()\n", + init_grid_size, + INIT_KERNEL_THREADS, + (long long)stream); + #endif + + // Invoke init_kernel to initialize tile descriptors + THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron( + init_grid_size, + INIT_KERNEL_THREADS, + 0, + stream) + .doit(init_kernel, tile_state, num_tiles); + + // Check for failure to launch + if (CubDebug(error = cudaPeekAtLastError())) + { + break; + } + + // Sync the stream if specified to flush runtime errors + error = detail::DebugSyncStream(stream); + if (CubDebug(error)) + { + break; + } + + // Get SM occupancy for scan_kernel + int scan_sm_occupancy; + if (CubDebug(error = MaxSmOccupancy(scan_sm_occupancy, // out + scan_kernel, + Policy::BLOCK_THREADS))) + { + break; + } + + // Get max x-dimension of grid + int max_dim_x; + if (CubDebug(error = cudaDeviceGetAttribute(&max_dim_x, + cudaDevAttrMaxGridDimX, + device_ordinal))) + { + break; + } + + // Run grids in epochs (in case number of tiles exceeds max x-dimension + int scan_grid_size = CUB_MIN(num_tiles, max_dim_x); + for (int start_tile = 0; start_tile < num_tiles; + start_tile += scan_grid_size) + { + // Log scan_kernel configuration + #ifdef CUB_DETAIL_DEBUG_ENABLE_LOG + _CubLog("Invoking %d scan_kernel<<<%d, %d, 0, %lld>>>(), %d items " + "per thread, %d SM occupancy\n", + start_tile, + scan_grid_size, + Policy::BLOCK_THREADS, + (long long)stream, + Policy::ITEMS_PER_THREAD, + scan_sm_occupancy); + #endif + + // Invoke scan_kernel + THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron( + scan_grid_size, + Policy::BLOCK_THREADS, + 0, + stream) + .doit(scan_kernel, + d_in, + d_out, + tile_state, + start_tile, + scan_op, + init_value, + num_items); - cudaError error = cudaSuccess; - do + // Check for failure to launch + if (CubDebug(error = cudaPeekAtLastError())) { - // Get device ordinal - int device_ordinal; - if (CubDebug(error = cudaGetDevice(&device_ordinal))) break; - - // Number of input tiles - int tile_size = Policy::BLOCK_THREADS * Policy::ITEMS_PER_THREAD; - int 
num_tiles = static_cast(cub::DivideAndRoundUp(num_items, tile_size)); - - // Specify temporary storage allocation requirements - size_t allocation_sizes[1]; - if (CubDebug(error = ScanTileStateT::AllocationSize(num_tiles, allocation_sizes[0]))) break; // bytes needed for tile status descriptors - - // Compute allocation pointers into the single storage blob (or compute the necessary size of the blob) - void* allocations[1] = {}; - if (CubDebug(error = AliasTemporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes))) break; - if (d_temp_storage == NULL) - { - // Return if the caller is simply requesting the size of the storage allocation - break; - } - - // Return if empty problem - if (num_items == 0) - break; - - // Construct the tile status interface - ScanTileStateT tile_state; - if (CubDebug(error = tile_state.Init(num_tiles, allocations[0], allocation_sizes[0]))) break; - - // Log init_kernel configuration - int init_grid_size = cub::DivideAndRoundUp(num_tiles, INIT_KERNEL_THREADS); - - #ifdef CUB_DETAIL_DEBUG_ENABLE_LOG - _CubLog("Invoking init_kernel<<<%d, %d, 0, %lld>>>()\n", init_grid_size, INIT_KERNEL_THREADS, (long long) stream); - #endif - - // Invoke init_kernel to initialize tile descriptors - THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron( - init_grid_size, INIT_KERNEL_THREADS, 0, stream - ).doit(init_kernel, - tile_state, - num_tiles); - - // Check for failure to launch - if (CubDebug(error = cudaPeekAtLastError())) - { - break; - } - - // Sync the stream if specified to flush runtime errors - error = detail::DebugSyncStream(stream); - if (CubDebug(error)) - { - break; - } - - // Get SM occupancy for scan_kernel - int scan_sm_occupancy; - if (CubDebug(error = MaxSmOccupancy( - scan_sm_occupancy, // out - scan_kernel, - Policy::BLOCK_THREADS))) break; - - // Get max x-dimension of grid - int max_dim_x; - if (CubDebug(error = cudaDeviceGetAttribute(&max_dim_x, cudaDevAttrMaxGridDimX, device_ordinal))) break; - - // Run grids in epochs (in case number of tiles exceeds max x-dimension - int scan_grid_size = CUB_MIN(num_tiles, max_dim_x); - for (int start_tile = 0; start_tile < num_tiles; start_tile += scan_grid_size) - { - // Log scan_kernel configuration - #ifdef CUB_DETAIL_DEBUG_ENABLE_LOG - _CubLog("Invoking %d scan_kernel<<<%d, %d, 0, %lld>>>(), %d items per thread, %d SM occupancy\n", - start_tile, scan_grid_size, Policy::BLOCK_THREADS, (long long) stream, Policy::ITEMS_PER_THREAD, scan_sm_occupancy); - #endif - - // Invoke scan_kernel - THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron( - scan_grid_size, Policy::BLOCK_THREADS, 0, stream - ).doit(scan_kernel, - d_in, - d_out, - tile_state, - start_tile, - scan_op, - init_value, - num_items); - - // Check for failure to launch - if (CubDebug(error = cudaPeekAtLastError())) - { - break; - } - - // Sync the stream if specified to flush runtime errors - error = detail::DebugSyncStream(stream); - if (CubDebug(error)) - { - break; - } - } + break; } - while (0); - - return error; - } - template - CUB_RUNTIME_FUNCTION __host__ __forceinline__ - cudaError_t Invoke() - { - typedef typename DispatchScan::MaxPolicy MaxPolicyT; - typedef typename cub::ScanTileState ScanTileStateT; - // Ensure kernels are instantiated. - return Invoke( - DeviceScanInitKernel, - DeviceScanKernel - ); - } - - - /** - * Internal dispatch routine - */ - CUB_RUNTIME_FUNCTION __forceinline__ - static cudaError_t Dispatch( - void* d_temp_storage, ///< [in] Device-accessible allocation of temporary storage. 
When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. - size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation - InputIteratorT d_in, ///< [in] Iterator to the input sequence of data items - OutputIteratorT d_out, ///< [out] Iterator to the output sequence of data items - ScanOpT scan_op, ///< [in] Binary scan functor - InitValueT init_value, ///< [in] Initial value to seed the exclusive scan - OffsetT num_items, ///< [in] Total number of input items (i.e., the length of \p d_in) - cudaStream_t stream) ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. - { - typedef typename DispatchScan::MaxPolicy MaxPolicyT; - - cudaError_t error; - do + // Sync the stream if specified to flush runtime errors + error = detail::DebugSyncStream(stream); + if (CubDebug(error)) { - // Get PTX version - int ptx_version = 0; - if (CubDebug(error = PtxVersion(ptx_version))) break; - - // Create dispatch functor - DispatchScan dispatch( - d_temp_storage, - temp_storage_bytes, - d_in, - d_out, - num_items, - scan_op, - init_value, - stream, - ptx_version - ); - // Dispatch to chained policy - if (CubDebug(error = MaxPolicyT::Invoke(ptx_version, dispatch))) break; + break; } - while (0); - - return error; - } - - CUB_DETAIL_RUNTIME_DEBUG_SYNC_IS_NOT_SUPPORTED - CUB_RUNTIME_FUNCTION __forceinline__ static cudaError_t - Dispatch(void *d_temp_storage, - size_t &temp_storage_bytes, - InputIteratorT d_in, - OutputIteratorT d_out, - ScanOpT scan_op, - InitValueT init_value, - OffsetT num_items, - cudaStream_t stream, - bool debug_synchronous) + } + } while (0); + + return error; + } + + template + CUB_RUNTIME_FUNCTION __host__ __forceinline__ cudaError_t Invoke() + { + typedef typename DispatchScan::MaxPolicy MaxPolicyT; + typedef typename cub::ScanTileState ScanTileStateT; + // Ensure kernels are instantiated. + return Invoke(DeviceScanInitKernel, + DeviceScanKernel); + } + + /** + * @brief Internal dispatch routine + * + * @param[in] d_temp_storage + * Device-accessible allocation of temporary storage. When `nullptr`, the + * required allocation size is written to `temp_storage_bytes` and no + * work is done. + * + * @param[in,out] temp_storage_bytes + * Reference to size in bytes of `d_temp_storage` allocation + * + * @param[in] d_in + * Iterator to the input sequence of data items + * + * @param[out] d_out + * Iterator to the output sequence of data items + * + * @param[in] scan_op + * Binary scan functor + * + * @param[in] init_value + * Initial value to seed the exclusive scan + * + * @param[in] num_items + * Total number of input items (i.e., the length of `d_in`) + * + * @param[in] stream + * **[optional]** CUDA stream to launch kernels within. + * Default is stream0. 
+ * + */ + CUB_RUNTIME_FUNCTION __forceinline__ static cudaError_t + Dispatch(void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + ScanOpT scan_op, + InitValueT init_value, + OffsetT num_items, + cudaStream_t stream) + { + typedef typename DispatchScan::MaxPolicy MaxPolicyT; + + cudaError_t error; + do { - CUB_DETAIL_RUNTIME_DEBUG_SYNC_USAGE_LOG - - return Dispatch(d_temp_storage, - temp_storage_bytes, - d_in, - d_out, - scan_op, - init_value, - num_items, - stream); - } + // Get PTX version + int ptx_version = 0; + if (CubDebug(error = PtxVersion(ptx_version))) + { + break; + } + + // Create dispatch functor + DispatchScan dispatch(d_temp_storage, + temp_storage_bytes, + d_in, + d_out, + num_items, + scan_op, + init_value, + stream, + ptx_version); + + // Dispatch to chained policy + if (CubDebug(error = MaxPolicyT::Invoke(ptx_version, dispatch))) + { + break; + } + } while (0); + + return error; + } + + CUB_DETAIL_RUNTIME_DEBUG_SYNC_IS_NOT_SUPPORTED + CUB_RUNTIME_FUNCTION __forceinline__ static cudaError_t + Dispatch(void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + ScanOpT scan_op, + InitValueT init_value, + OffsetT num_items, + cudaStream_t stream, + bool debug_synchronous) + { + CUB_DETAIL_RUNTIME_DEBUG_SYNC_USAGE_LOG + + return Dispatch(d_temp_storage, + temp_storage_bytes, + d_in, + d_out, + scan_op, + init_value, + num_items, + stream); + } }; - - CUB_NAMESPACE_END + diff --git a/cub/device/dispatch/dispatch_scan_by_key.cuh b/cub/device/dispatch/dispatch_scan_by_key.cuh index ba787705f3..1a4cfe7c4f 100644 --- a/cub/device/dispatch/dispatch_scan_by_key.cuh +++ b/cub/device/dispatch/dispatch_scan_by_key.cuh @@ -12,10 +12,10 @@ * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND @@ -26,88 +26,142 @@ ******************************************************************************/ /** - * \file - * DeviceScan provides device-wide, parallel operations for computing a prefix scan across a sequence of data items residing within device-accessible memory. + * @file DeviceScan provides device-wide, parallel operations for computing a + * prefix scan across a sequence of data items residing within + * device-accessible memory. 
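As that file comment says, the by-key variant restarts the running prefix whenever the key changes. A small sketch of the intended semantics, assuming the public cub::DeviceScan::ExclusiveSumByKey wrapper (which dispatches through the DispatchScanByKey machinery defined below); the literal data are illustrative:

#include <cub/device/device_scan.cuh>
#include <cuda_runtime.h>

// keys   : 0 0 1 1 1 2
// values : 1 1 1 1 1 1
// result : 0 1 0 1 2 0   (the exclusive sum restarts at every key boundary)
void exclusive_sum_by_key_example(const int *d_keys, const int *d_values,
                                  int *d_out, int num_items,
                                  cudaStream_t stream)
{
  void *d_temp_storage = nullptr;
  size_t temp_storage_bytes = 0;

  // Size query, then the actual launch (same two-phase pattern as above).
  cub::DeviceScan::ExclusiveSumByKey(d_temp_storage, temp_storage_bytes,
                                     d_keys, d_values, d_out, num_items,
                                     cub::Equality(), stream);
  cudaMalloc(&d_temp_storage, temp_storage_bytes);
  cub::DeviceScan::ExclusiveSumByKey(d_temp_storage, temp_storage_bytes,
                                     d_keys, d_values, d_out, num_items,
                                     cub::Equality(), stream);
  cudaFree(d_temp_storage);
}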
*/ #pragma once #include -#include "../../agent/agent_scan_by_key.cuh" -#include "../../thread/thread_operators.cuh" -#include "../../config.cuh" -#include "../../util_debug.cuh" -#include "../../util_device.cuh" -#include "../../util_math.cuh" -#include "dispatch_scan.cuh" +#include +#include +#include +#include +#include +#include +#include +#include #include CUB_NAMESPACE_BEGIN - /****************************************************************************** -* Kernel entry points -*****************************************************************************/ + * Kernel entry points + *****************************************************************************/ /** - * Scan kernel entry point (multi-block) + * @brief Scan kernel entry point (multi-block) + * + * @tparam ChainedPolicyT + * Chained tuning policy + * + * @tparam KeysInputIteratorT + * Random-access input iterator type + * + * @tparam ValuesInputIteratorT + * Random-access input iterator type + * + * @tparam ValuesOutputIteratorT + * Random-access output iterator type + * + * @tparam ScanByKeyTileStateT + * Tile status interface type + * + * @tparam EqualityOp + * Equality functor type + * + * @tparam ScanOpT + * Scan functor type + * + * @tparam InitValueT + * The init_value element for ScanOpT type (cub::NullType for inclusive scan) + * + * @tparam OffsetT + * Signed integer type for global offsets + * + * @param d_keys_in + * Input keys data + * + * @param d_keys_prev_in + * Predecessor items for each tile + * + * @param d_values_in + * Input values data + * + * @param d_values_out + * Output values data + * + * @param tile_state + * Tile status interface + * + * @param start_tile + * The starting tile for the current grid + * + * @param equality_op + * Binary equality functor + * + * @param scan_op + * Binary scan functor + * + * @param init_value + * Initial value to seed the exclusive scan + * + * @param num_items + * Total number of scan items for the entire problem */ -template < - typename ChainedPolicyT, ///< Chained tuning policy - typename KeysInputIteratorT, ///< Random-access input iterator type - typename ValuesInputIteratorT, ///< Random-access input iterator type - typename ValuesOutputIteratorT, ///< Random-access output iterator type - typename ScanByKeyTileStateT, ///< Tile status interface type - typename EqualityOp, ///< Equality functor type - typename ScanOpT, ///< Scan functor type - typename InitValueT, ///< The init_value element for ScanOpT type (cub::NullType for inclusive scan) - typename OffsetT, ///< Signed integer type for global offsets - typename KeyT = cub::detail::value_t> -__launch_bounds__ (int(ChainedPolicyT::ActivePolicy::ScanByKeyPolicyT::BLOCK_THREADS)) -__global__ void DeviceScanByKeyKernel( - KeysInputIteratorT d_keys_in, ///< Input keys data - KeyT *d_keys_prev_in, ///< Predecessor items for each tile - ValuesInputIteratorT d_values_in, ///< Input values data - ValuesOutputIteratorT d_values_out, ///< Output values data - ScanByKeyTileStateT tile_state, ///< Tile status interface - int start_tile, ///< The starting tile for the current grid - EqualityOp equality_op, ///< Binary equality functor - ScanOpT scan_op, ///< Binary scan functor - InitValueT init_value, ///< Initial value to seed the exclusive scan - OffsetT num_items) ///< Total number of scan items for the entire problem +template > +__launch_bounds__(int(ChainedPolicyT::ActivePolicy::ScanByKeyPolicyT::BLOCK_THREADS)) +__global__ void DeviceScanByKeyKernel(KeysInputIteratorT d_keys_in, + KeyT *d_keys_prev_in, + 
ValuesInputIteratorT d_values_in, + ValuesOutputIteratorT d_values_out, + ScanByKeyTileStateT tile_state, + int start_tile, + EqualityOp equality_op, + ScanOpT scan_op, + InitValueT init_value, + OffsetT num_items) { - typedef typename ChainedPolicyT::ActivePolicy::ScanByKeyPolicyT ScanByKeyPolicyT; - - // Thread block type for scanning input tiles - typedef AgentScanByKey< - ScanByKeyPolicyT, - KeysInputIteratorT, - ValuesInputIteratorT, - ValuesOutputIteratorT, - EqualityOp, - ScanOpT, - InitValueT, - OffsetT> AgentScanByKeyT; - - // Shared memory for AgentScanByKey - __shared__ typename AgentScanByKeyT::TempStorage temp_storage; - - // Process tiles - AgentScanByKeyT( - temp_storage, - d_keys_in, - d_keys_prev_in, - d_values_in, - d_values_out, - equality_op, - scan_op, - init_value - ).ConsumeRange( - num_items, - tile_state, - start_tile); + using ScanByKeyPolicyT = + typename ChainedPolicyT::ActivePolicy::ScanByKeyPolicyT; + + // Thread block type for scanning input tiles + using AgentScanByKeyT = AgentScanByKey; + + // Shared memory for AgentScanByKey + __shared__ typename AgentScanByKeyT::TempStorage temp_storage; + + // Process tiles + AgentScanByKeyT(temp_storage, + d_keys_in, + d_keys_prev_in, + d_values_in, + d_values_out, + equality_op, + scan_op, + init_value) + .ConsumeRange(num_items, tile_state, start_tile); } template @@ -135,390 +189,544 @@ __global__ void DeviceScanByKeyInitKernel( ******************************************************************************/ template + typename AccumT> struct DeviceScanByKeyPolicy { - using KeyT = cub::detail::value_t; - using ValueT = cub::detail::conditional_t< - std::is_same::value, - cub::detail::value_t, - InitValueT>; - static constexpr size_t MaxInputBytes = (sizeof(KeyT) > sizeof(ValueT) ? sizeof(KeyT) : sizeof(ValueT)); - static constexpr size_t CombinedInputBytes = sizeof(KeyT) + sizeof(ValueT); + using KeyT = cub::detail::value_t; - // SM350 - struct Policy350 : ChainedPolicy<350, Policy350, Policy350> - { - enum - { - NOMINAL_4B_ITEMS_PER_THREAD = 6, - ITEMS_PER_THREAD = ((MaxInputBytes <= 8) ? 6 : - Nominal4BItemsToItemsCombined(NOMINAL_4B_ITEMS_PER_THREAD, CombinedInputBytes)), - }; - - typedef AgentScanByKeyPolicy< - 128, ITEMS_PER_THREAD, - BLOCK_LOAD_WARP_TRANSPOSE, - LOAD_CA, - BLOCK_SCAN_WARP_SCANS, - BLOCK_STORE_WARP_TRANSPOSE> - ScanByKeyPolicyT; - }; - - // SM520 - struct Policy520 : ChainedPolicy<520, Policy520, Policy350> - { - enum - { - NOMINAL_4B_ITEMS_PER_THREAD = 9, - - ITEMS_PER_THREAD = ((MaxInputBytes <= 8) ? 9 : - Nominal4BItemsToItemsCombined(NOMINAL_4B_ITEMS_PER_THREAD, CombinedInputBytes)), - }; - - typedef AgentScanByKeyPolicy< - 256, ITEMS_PER_THREAD, - BLOCK_LOAD_WARP_TRANSPOSE, - LOAD_CA, - BLOCK_SCAN_WARP_SCANS, - BLOCK_STORE_WARP_TRANSPOSE> - ScanByKeyPolicyT; - }; - - /// MaxPolicy - typedef Policy520 MaxPolicy; -}; + static constexpr size_t MaxInputBytes = (cub::max)(sizeof(KeyT), + sizeof(AccumT)); + + static constexpr size_t CombinedInputBytes = sizeof(KeyT) + sizeof(AccumT); + // SM350 + struct Policy350 : ChainedPolicy<350, Policy350, Policy350> + { + static constexpr int NOMINAL_4B_ITEMS_PER_THREAD = 6; + static constexpr int ITEMS_PER_THREAD = + ((MaxInputBytes <= 8) + ? 
6 + : Nominal4BItemsToItemsCombined(NOMINAL_4B_ITEMS_PER_THREAD, + CombinedInputBytes)); + + using ScanByKeyPolicyT = AgentScanByKeyPolicy<128, + ITEMS_PER_THREAD, + BLOCK_LOAD_WARP_TRANSPOSE, + LOAD_CA, + BLOCK_SCAN_WARP_SCANS, + BLOCK_STORE_WARP_TRANSPOSE>; + }; + + // SM520 + struct Policy520 : ChainedPolicy<520, Policy520, Policy350> + { + static constexpr int NOMINAL_4B_ITEMS_PER_THREAD = 9; + static constexpr int ITEMS_PER_THREAD = + ((MaxInputBytes <= 8) + ? 9 + : Nominal4BItemsToItemsCombined(NOMINAL_4B_ITEMS_PER_THREAD, + CombinedInputBytes)); + + using ScanByKeyPolicyT = AgentScanByKeyPolicy<256, + ITEMS_PER_THREAD, + BLOCK_LOAD_WARP_TRANSPOSE, + LOAD_CA, + BLOCK_SCAN_WARP_SCANS, + BLOCK_STORE_WARP_TRANSPOSE>; + }; + + using MaxPolicy = Policy520; +}; /****************************************************************************** * Dispatch ******************************************************************************/ - /** - * Utility class for dispatching the appropriately-tuned kernels for DeviceScan + * @brief Utility class for dispatching the appropriately-tuned kernels + * for DeviceScan + * + * @tparam KeysInputIteratorT + * Random-access input iterator type + * + * @tparam ValuesInputIteratorT + * Random-access input iterator type + * + * @tparam ValuesOutputIteratorT + * Random-access output iterator type + * + * @tparam EqualityOp + * Equality functor type + * + * @tparam ScanOpT + * Scan functor type + * + * @tparam InitValueT + * The init_value element for ScanOpT type (cub::NullType for inclusive scan) + * + * @tparam OffsetT + * Signed integer type for global offsets + * */ template < - typename KeysInputIteratorT, ///< Random-access input iterator type - typename ValuesInputIteratorT, ///< Random-access input iterator type - typename ValuesOutputIteratorT, ///< Random-access output iterator type - typename EqualityOp, ///< Equality functor type - typename ScanOpT, ///< Scan functor type - typename InitValueT, ///< The init_value element for ScanOpT type (cub::NullType for inclusive scan) - typename OffsetT, ///< Signed integer type for global offsets - typename SelectedPolicy = DeviceScanByKeyPolicy -> -struct DispatchScanByKey: - SelectedPolicy + typename KeysInputIteratorT, + typename ValuesInputIteratorT, + typename ValuesOutputIteratorT, + typename EqualityOp, + typename ScanOpT, + typename InitValueT, + typename OffsetT, + typename AccumT = + detail::accumulator_t< + ScanOpT, + cub::detail::conditional_t< + std::is_same::value, + cub::detail::value_t, + InitValueT>, + cub::detail::value_t>, + typename SelectedPolicy = + DeviceScanByKeyPolicy> +struct DispatchScanByKey : SelectedPolicy { - //--------------------------------------------------------------------- - // Constants and Types - //--------------------------------------------------------------------- - - enum - { - INIT_KERNEL_THREADS = 128 - }; - - // The input key type - using KeyT = cub::detail::value_t; - - // The input value type - using InputT = cub::detail::value_t; - - // The output value type -- used as the intermediate accumulator - // Per https://wg21.link/P0571, use InitValueT if provided, otherwise the - // input iterator's value type. - using OutputT = - cub::detail::conditional_t::value, - InputT, - InitValueT>; - - void* d_temp_storage; ///< [in] Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. 
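// For illustration only: the new AccumT template parameter above defaults to
// cub::detail::accumulator_t applied to the scan operator, the initial-value
// type (or the input value type for inclusive scans), and the input value
// type. The helper below is a rough, hypothetical approximation of what such
// a trait computes -- the decayed result type of invoking the operator -- and
// is not the actual CUB implementation.

#include <type_traits>
#include <utility>

template <typename ScanOpT, typename InitT, typename InputT>
using accumulator_t_sketch =
  typename std::decay<decltype(std::declval<ScanOpT>()(
    std::declval<InitT>(), std::declval<InputT>()))>::type;

// A heterogeneous sum functor: the result type follows the usual arithmetic
// promotion rules rather than either operand type alone.
struct SumSketch
{
  template <typename A, typename B>
  auto operator()(A a, B b) const -> decltype(a + b) { return a + b; }
};

static_assert(std::is_same<accumulator_t_sketch<SumSketch, int, short>, int>::value,
              "short inputs accumulate in int when the initial value is int");
static_assert(std::is_same<accumulator_t_sketch<SumSketch, float, int>, float>::value,
              "int inputs accumulate in float when the initial value is float");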
- size_t& temp_storage_bytes; ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation - KeysInputIteratorT d_keys_in; ///< [in] Iterator to the input sequence of key items - ValuesInputIteratorT d_values_in; ///< [in] Iterator to the input sequence of value items - ValuesOutputIteratorT d_values_out; ///< [out] Iterator to the input sequence of value items - EqualityOp equality_op; ///< [in]Binary equality functor - ScanOpT scan_op; ///< [in] Binary scan functor - InitValueT init_value; ///< [in] Initial value to seed the exclusive scan - OffsetT num_items; ///< [in] Total number of input items (i.e., the length of \p d_in) - cudaStream_t stream; ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. - int ptx_version; - - CUB_RUNTIME_FUNCTION __forceinline__ - DispatchScanByKey( - void* d_temp_storage, ///< [in] Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. - size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation - KeysInputIteratorT d_keys_in, ///< [in] Iterator to the input sequence of key items - ValuesInputIteratorT d_values_in, ///< [in] Iterator to the input sequence of value items - ValuesOutputIteratorT d_values_out, ///< [out] Iterator to the input sequence of value items - EqualityOp equality_op, ///< [in] Binary equality functor - ScanOpT scan_op, ///< [in] Binary scan functor - InitValueT init_value, ///< [in] Initial value to seed the exclusive scan - OffsetT num_items, ///< [in] Total number of input items (i.e., the length of \p d_in) - cudaStream_t stream, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. - int ptx_version - ): - d_temp_storage(d_temp_storage), - temp_storage_bytes(temp_storage_bytes), - d_keys_in(d_keys_in), - d_values_in(d_values_in), - d_values_out(d_values_out), - equality_op(equality_op), - scan_op(scan_op), - init_value(init_value), - num_items(num_items), - stream(stream), - ptx_version(ptx_version) - {} - - CUB_DETAIL_RUNTIME_DEBUG_SYNC_IS_NOT_SUPPORTED - CUB_RUNTIME_FUNCTION __forceinline__ - DispatchScanByKey(void *d_temp_storage, - size_t &temp_storage_bytes, - KeysInputIteratorT d_keys_in, - ValuesInputIteratorT d_values_in, - ValuesOutputIteratorT d_values_out, - EqualityOp equality_op, - ScanOpT scan_op, - InitValueT init_value, - OffsetT num_items, - cudaStream_t stream, - bool debug_synchronous, - int ptx_version) - : d_temp_storage(d_temp_storage) - , temp_storage_bytes(temp_storage_bytes) - , d_keys_in(d_keys_in) - , d_values_in(d_values_in) - , d_values_out(d_values_out) - , equality_op(equality_op) - , scan_op(scan_op) - , init_value(init_value) - , num_items(num_items) - , stream(stream) - , ptx_version(ptx_version) - { - CUB_DETAIL_RUNTIME_DEBUG_SYNC_USAGE_LOG - } - - template - CUB_RUNTIME_FUNCTION __host__ __forceinline__ - cudaError_t Invoke(InitKernel init_kernel, ScanKernel scan_kernel) - { - typedef typename ActivePolicyT::ScanByKeyPolicyT Policy; - typedef ReduceByKeyScanTileState ScanByKeyTileStateT; - - cudaError error = cudaSuccess; - do - { - // Get device ordinal - int device_ordinal; - if (CubDebug(error = cudaGetDevice(&device_ordinal))) break; - - // Number of input tiles - int tile_size = Policy::BLOCK_THREADS * Policy::ITEMS_PER_THREAD; - int num_tiles = static_cast(cub::DivideAndRoundUp(num_items, tile_size)); - - // Specify temporary storage allocation requirements - size_t allocation_sizes[2]; - if 
(CubDebug(error = ScanByKeyTileStateT::AllocationSize(num_tiles, allocation_sizes[0]))) break; // bytes needed for tile status descriptors - - allocation_sizes[1] = sizeof(KeyT) * (num_tiles + 1); - - // Compute allocation pointers into the single storage blob (or compute the necessary size of the blob) - void* allocations[2] = {}; - if (CubDebug(error = AliasTemporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes))) break; - if (d_temp_storage == NULL) - { - // Return if the caller is simply requesting the size of the storage allocation - break; - } - - // Return if empty problem - if (num_items == 0) - break; - - KeyT *d_keys_prev_in = reinterpret_cast(allocations[1]); - - // Construct the tile status interface - ScanByKeyTileStateT tile_state; - if (CubDebug(error = tile_state.Init(num_tiles, allocations[0], allocation_sizes[0]))) break; - - // Log init_kernel configuration - int init_grid_size = cub::DivideAndRoundUp(num_tiles, INIT_KERNEL_THREADS); - - #ifdef CUB_DETAIL_DEBUG_ENABLE_LOG - _CubLog("Invoking init_kernel<<<%d, %d, 0, %lld>>>()\n", init_grid_size, INIT_KERNEL_THREADS, (long long) stream); - #endif - - // Invoke init_kernel to initialize tile descriptors - THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron( - init_grid_size, INIT_KERNEL_THREADS, 0, stream - ).doit(init_kernel, tile_state, d_keys_in, d_keys_prev_in, tile_size, num_tiles); - - // Check for failure to launch - if (CubDebug(error = cudaPeekAtLastError())) - { - break; - } - - // Sync the stream if specified to flush runtime errors - error = detail::DebugSyncStream(stream); - if (CubDebug(error)) - { - break; - } - - // Get SM occupancy for scan_kernel - int scan_sm_occupancy; - if (CubDebug(error = MaxSmOccupancy( - scan_sm_occupancy, // out - scan_kernel, - Policy::BLOCK_THREADS))) break; - - // Get max x-dimension of grid - int max_dim_x; - if (CubDebug(error = cudaDeviceGetAttribute(&max_dim_x, cudaDevAttrMaxGridDimX, device_ordinal))) break; - - // Run grids in epochs (in case number of tiles exceeds max x-dimension - int scan_grid_size = CUB_MIN(num_tiles, max_dim_x); - for (int start_tile = 0; start_tile < num_tiles; start_tile += scan_grid_size) - { - // Log scan_kernel configuration - #ifdef CUB_DETAIL_DEBUG_ENABLE_LOG - _CubLog("Invoking %d scan_kernel<<<%d, %d, 0, %lld>>>(), %d items per thread, %d SM occupancy\n", - start_tile, scan_grid_size, Policy::BLOCK_THREADS, (long long) stream, Policy::ITEMS_PER_THREAD, scan_sm_occupancy); - #endif - - // Invoke scan_kernel - THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron( - scan_grid_size, Policy::BLOCK_THREADS, 0, stream - ).doit( - scan_kernel, - d_keys_in, - d_keys_prev_in, - d_values_in, - d_values_out, - tile_state, - start_tile, - equality_op, - scan_op, - init_value, - num_items); - - // Check for failure to launch - if (CubDebug(error = cudaPeekAtLastError())) - { - break; - } - - // Sync the stream if specified to flush runtime errors - error = detail::DebugSyncStream(stream); - if (CubDebug(error)) - { - break; - } - } - } - while (0); + //--------------------------------------------------------------------- + // Constants and Types + //--------------------------------------------------------------------- + + static constexpr int INIT_KERNEL_THREADS = 128; + + // The input key type + using KeyT = cub::detail::value_t; + + // The input value type + using InputT = cub::detail::value_t; + + /// Device-accessible allocation of temporary storage. 
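// For illustration only: like every CUB dispatch layer, this one sizes its
// temporary storage on a first call made with a null d_temp_storage pointer
// and performs the actual work on a second call. A hypothetical host-side
// caller using the public DeviceScan::ExclusiveSumByKey entry point might
// look like the sketch below (error handling trimmed; the exact public
// signature is assumed from the CUB version this change targets).

#include <cub/device/device_scan.cuh>
#include <cuda_runtime.h>

inline cudaError_t ExclusiveSumByKeySketch(const int *d_keys_in,
                                           const int *d_values_in,
                                           int *d_values_out,
                                           int num_items)
{
  void *d_temp_storage      = nullptr;
  size_t temp_storage_bytes = 0;

  // First call: d_temp_storage is null, so only temp_storage_bytes is written.
  cudaError_t error = cub::DeviceScan::ExclusiveSumByKey(d_temp_storage,
                                                         temp_storage_bytes,
                                                         d_keys_in,
                                                         d_values_in,
                                                         d_values_out,
                                                         num_items);
  if (error != cudaSuccess) { return error; }

  error = cudaMalloc(&d_temp_storage, temp_storage_bytes);
  if (error != cudaSuccess) { return error; }

  // Second call: the same arguments, now with real temporary storage.
  error = cub::DeviceScan::ExclusiveSumByKey(d_temp_storage,
                                             temp_storage_bytes,
                                             d_keys_in,
                                             d_values_in,
                                             d_values_out,
                                             num_items);

  cudaFree(d_temp_storage);
  return error;
}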
When `nullptr`, the + /// required allocation size is written to `temp_storage_bytes` and no work + /// is done. + void *d_temp_storage; + + /// Reference to size in bytes of `d_temp_storage` allocation + size_t &temp_storage_bytes; + + /// Iterator to the input sequence of key items + KeysInputIteratorT d_keys_in; + + /// Iterator to the input sequence of value items + ValuesInputIteratorT d_values_in; + + /// Iterator to the input sequence of value items + ValuesOutputIteratorT d_values_out; + + /// Binary equality functor + EqualityOp equality_op; + + /// Binary scan functor + ScanOpT scan_op; + + /// Initial value to seed the exclusive scan + InitValueT init_value; + + /// Total number of input items (i.e., the length of `d_in`) + OffsetT num_items; + + /// CUDA stream to launch kernels within. + cudaStream_t stream; + int ptx_version; + + /** + * @param[in] d_temp_storage + * Device-accessible allocation of temporary storage. When `nullptr`, the + * required allocation size is written to `temp_storage_bytes` and no + * work is done. + * + * @param[in,out] temp_storage_bytes + * Reference to size in bytes of `d_temp_storage` allocation + * + * @param[in] d_keys_in + * Iterator to the input sequence of key items + * + * @param[in] d_values_in + * Iterator to the input sequence of value items + * + * @param[out] d_values_out + * Iterator to the input sequence of value items + * + * @param[in] equality_op + * Binary equality functor + * + * @param[in] scan_op + * Binary scan functor + * + * @param[in] init_value + * Initial value to seed the exclusive scan + * + * @param[in] num_items + * Total number of input items (i.e., the length of `d_in`) + * + * @param[in] stream + * CUDA stream to launch kernels within. + */ + CUB_RUNTIME_FUNCTION __forceinline__ + DispatchScanByKey(void *d_temp_storage, + size_t &temp_storage_bytes, + KeysInputIteratorT d_keys_in, + ValuesInputIteratorT d_values_in, + ValuesOutputIteratorT d_values_out, + EqualityOp equality_op, + ScanOpT scan_op, + InitValueT init_value, + OffsetT num_items, + cudaStream_t stream, + int ptx_version) + : d_temp_storage(d_temp_storage) + , temp_storage_bytes(temp_storage_bytes) + , d_keys_in(d_keys_in) + , d_values_in(d_values_in) + , d_values_out(d_values_out) + , equality_op(equality_op) + , scan_op(scan_op) + , init_value(init_value) + , num_items(num_items) + , stream(stream) + , ptx_version(ptx_version) + {} + + CUB_DETAIL_RUNTIME_DEBUG_SYNC_IS_NOT_SUPPORTED + CUB_RUNTIME_FUNCTION __forceinline__ + DispatchScanByKey(void *d_temp_storage, + size_t &temp_storage_bytes, + KeysInputIteratorT d_keys_in, + ValuesInputIteratorT d_values_in, + ValuesOutputIteratorT d_values_out, + EqualityOp equality_op, + ScanOpT scan_op, + InitValueT init_value, + OffsetT num_items, + cudaStream_t stream, + bool debug_synchronous, + int ptx_version) + : d_temp_storage(d_temp_storage) + , temp_storage_bytes(temp_storage_bytes) + , d_keys_in(d_keys_in) + , d_values_in(d_values_in) + , d_values_out(d_values_out) + , equality_op(equality_op) + , scan_op(scan_op) + , init_value(init_value) + , num_items(num_items) + , stream(stream) + , ptx_version(ptx_version) + { + CUB_DETAIL_RUNTIME_DEBUG_SYNC_USAGE_LOG + } - return error; - } + template + CUB_RUNTIME_FUNCTION __host__ __forceinline__ cudaError_t + Invoke(InitKernel init_kernel, ScanKernel scan_kernel) + { + using Policy = typename ActivePolicyT::ScanByKeyPolicyT; + using ScanByKeyTileStateT = ReduceByKeyScanTileState; - template - CUB_RUNTIME_FUNCTION __host__ __forceinline__ - cudaError_t Invoke() 
- { - typedef typename DispatchScanByKey::MaxPolicy MaxPolicyT; - typedef ReduceByKeyScanTileState ScanByKeyTileStateT; - // Ensure kernels are instantiated. - return Invoke( - DeviceScanByKeyInitKernel, - DeviceScanByKeyKernel< - MaxPolicyT, KeysInputIteratorT, ValuesInputIteratorT, ValuesOutputIteratorT, - ScanByKeyTileStateT, EqualityOp, ScanOpT, InitValueT, OffsetT> - ); - } - - - /** - * Internal dispatch routine - */ - CUB_RUNTIME_FUNCTION __forceinline__ - static cudaError_t Dispatch( - void* d_temp_storage, ///< [in] Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. - size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation - KeysInputIteratorT d_keys_in, ///< [in] Iterator to the input sequence of key items - ValuesInputIteratorT d_values_in, ///< [in] Iterator to the input sequence of value items - ValuesOutputIteratorT d_values_out, ///< [out] Iterator to the input sequence of value items - EqualityOp equality_op, ///< [in] Binary equality functor - ScanOpT scan_op, ///< [in] Binary scan functor - InitValueT init_value, ///< [in] Initial value to seed the exclusive scan - OffsetT num_items, ///< [in] Total number of input items (i.e., the length of \p d_in) - cudaStream_t stream) ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + cudaError error = cudaSuccess; + do { - typedef typename DispatchScanByKey::MaxPolicy MaxPolicyT; - - cudaError_t error; - do - { - // Get PTX version - int ptx_version = 0; - if (CubDebug(error = PtxVersion(ptx_version))) break; - - // Create dispatch functor - DispatchScanByKey dispatch( - d_temp_storage, - temp_storage_bytes, + // Get device ordinal + int device_ordinal; + if (CubDebug(error = cudaGetDevice(&device_ordinal))) + { + break; + } + + // Number of input tiles + int tile_size = Policy::BLOCK_THREADS * Policy::ITEMS_PER_THREAD; + int num_tiles = + static_cast(cub::DivideAndRoundUp(num_items, tile_size)); + + // Specify temporary storage allocation requirements + size_t allocation_sizes[2]; + if (CubDebug( + error = ScanByKeyTileStateT::AllocationSize(num_tiles, + allocation_sizes[0]))) + { + break; // bytes needed for tile status descriptors + } + + allocation_sizes[1] = sizeof(KeyT) * (num_tiles + 1); + + // Compute allocation pointers into the single storage blob (or compute + // the necessary size of the blob) + void *allocations[2] = {}; + if (CubDebug(error = AliasTemporaries(d_temp_storage, + temp_storage_bytes, + allocations, + allocation_sizes))) + { + break; + } + + if (d_temp_storage == NULL) + { + // Return if the caller is simply requesting the size of the storage + // allocation + break; + } + + // Return if empty problem + if (num_items == 0) + { + break; + } + + KeyT *d_keys_prev_in = reinterpret_cast(allocations[1]); + + // Construct the tile status interface + ScanByKeyTileStateT tile_state; + if (CubDebug(error = tile_state.Init(num_tiles, + allocations[0], + allocation_sizes[0]))) + { + break; + } + + // Log init_kernel configuration + int init_grid_size = cub::DivideAndRoundUp(num_tiles, + INIT_KERNEL_THREADS); + #ifdef CUB_DETAIL_DEBUG_ENABLE_LOG + _CubLog("Invoking init_kernel<<<%d, %d, 0, %lld>>>()\n", + init_grid_size, + INIT_KERNEL_THREADS, + (long long)stream); + #endif + + // Invoke init_kernel to initialize tile descriptors + THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron( + init_grid_size, + INIT_KERNEL_THREADS, + 0, + stream) + 
.doit(init_kernel, + tile_state, + d_keys_in, + d_keys_prev_in, + tile_size, + num_tiles); + + // Check for failure to launch + if (CubDebug(error = cudaPeekAtLastError())) + { + break; + } + + // Sync the stream if specified to flush runtime errors + + error = detail::DebugSyncStream(stream); + if (CubDebug(error)) + { + break; + } + + // Get SM occupancy for scan_kernel + int scan_sm_occupancy; + if (CubDebug(error = MaxSmOccupancy(scan_sm_occupancy, // out + scan_kernel, + Policy::BLOCK_THREADS))) + { + break; + } + + // Get max x-dimension of grid + int max_dim_x; + if (CubDebug(error = cudaDeviceGetAttribute(&max_dim_x, + cudaDevAttrMaxGridDimX, + device_ordinal))) + { + break; + } + + // Run grids in epochs (in case number of tiles exceeds max x-dimension + int scan_grid_size = CUB_MIN(num_tiles, max_dim_x); + for (int start_tile = 0; start_tile < num_tiles; + start_tile += scan_grid_size) + { + // Log scan_kernel configuration + #ifdef CUB_DETAIL_DEBUG_ENABLE_LOG + _CubLog("Invoking %d scan_kernel<<<%d, %d, 0, %lld>>>(), %d items " + "per thread, %d SM occupancy\n", + start_tile, + scan_grid_size, + Policy::BLOCK_THREADS, + (long long)stream, + Policy::ITEMS_PER_THREAD, + scan_sm_occupancy); + #endif + + // Invoke scan_kernel + THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron( + scan_grid_size, + Policy::BLOCK_THREADS, + 0, + stream) + .doit(scan_kernel, d_keys_in, + d_keys_prev_in, d_values_in, d_values_out, + tile_state, + start_tile, equality_op, scan_op, init_value, - num_items, - stream, - ptx_version - ); - // Dispatch to chained policy - if (CubDebug(error = MaxPolicyT::Invoke(ptx_version, dispatch))) break; + num_items); + + // Check for failure to launch + if (CubDebug(error = cudaPeekAtLastError())) + { + break; } - while (0); - - return error; - } - - CUB_DETAIL_RUNTIME_DEBUG_SYNC_IS_NOT_SUPPORTED - CUB_RUNTIME_FUNCTION __forceinline__ - static cudaError_t Dispatch( - void* d_temp_storage, - size_t& temp_storage_bytes, - KeysInputIteratorT d_keys_in, - ValuesInputIteratorT d_values_in, - ValuesOutputIteratorT d_values_out, - EqualityOp equality_op, - ScanOpT scan_op, - InitValueT init_value, - OffsetT num_items, - cudaStream_t stream, - bool debug_synchronous) + + // Sync the stream if specified to flush runtime errors + error = detail::DebugSyncStream(stream); + if (CubDebug(error)) + { + break; + } + } + } while (0); + + return error; + } + + template + CUB_RUNTIME_FUNCTION __host__ __forceinline__ cudaError_t Invoke() + { + using MaxPolicyT = typename DispatchScanByKey::MaxPolicy; + using ScanByKeyTileStateT = ReduceByKeyScanTileState; + + // Ensure kernels are instantiated. + return Invoke( + DeviceScanByKeyInitKernel, + DeviceScanByKeyKernel); + } + + /** + * @brief Internal dispatch routine + * + * @param[in] d_temp_storage + * Device-accessible allocation of temporary storage. When `nullptr`, the + * required allocation size is written to `temp_storage_bytes` and no + * work is done. 
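// For illustration only: the Dispatch() routine documented here follows CUB's
// chained-policy pattern -- query the device's PTX version once, then walk a
// compile-time chain of tuning policies until the first link whose minimum
// architecture is satisfied invokes the dispatch functor. Below is a
// simplified, self-contained model of that selection; all names are
// hypothetical, not the real ChainedPolicy implementation.

template <int MinPtxVersion, typename PolicyT, typename PrevPolicyT>
struct ChainedPolicySketch
{
  template <typename FunctorT>
  static int Invoke(int ptx_version, FunctorT &f)
  {
    if (ptx_version < MinPtxVersion)
    {
      return PrevPolicyT::Invoke(ptx_version, f); // fall through to older policy
    }
    return f.template Invoke<PolicyT>();          // this policy is active
  }
};

// Terminal link: the oldest policy handles every remaining architecture.
template <int MinPtxVersion, typename PolicyT>
struct ChainedPolicySketch<MinPtxVersion, PolicyT, PolicyT>
{
  template <typename FunctorT>
  static int Invoke(int /*ptx_version*/, FunctorT &f)
  {
    return f.template Invoke<PolicyT>();
  }
};

struct Policy350Sketch : ChainedPolicySketch<350, Policy350Sketch, Policy350Sketch> {};
struct Policy520Sketch : ChainedPolicySketch<520, Policy520Sketch, Policy350Sketch> {};
using MaxPolicySketch = Policy520Sketch;

struct DispatchFunctorSketch
{
  template <typename ActivePolicyT>
  int Invoke()
  {
    // A real dispatcher would launch kernels configured by the tuning
    // constants carried in ActivePolicyT; this model only reports success.
    return 0;
  }
};

// Usage: MaxPolicySketch::Invoke(ptx_version, functor) picks Policy520Sketch
// on SM 5.2 and newer devices and falls back to Policy350Sketch otherwise.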
+ * + * @param[in,out] temp_storage_bytes + * Reference to size in bytes of `d_temp_storage` allocation + * + * @param[in] d_keys_in + * Iterator to the input sequence of key items + * + * @param[in] d_values_in + * Iterator to the input sequence of value items + * + * @param[out] d_values_out + * Iterator to the input sequence of value items + * + * @param[in] equality_op + * Binary equality functor + * + * @param[in] scan_op + * Binary scan functor + * + * @param[in] init_value + * Initial value to seed the exclusive scan + * + * @param[in] num_items + * Total number of input items (i.e., the length of `d_in`) + * + * @param[in] stream + * CUDA stream to launch kernels within. + */ + CUB_RUNTIME_FUNCTION __forceinline__ static cudaError_t + Dispatch(void *d_temp_storage, + size_t &temp_storage_bytes, + KeysInputIteratorT d_keys_in, + ValuesInputIteratorT d_values_in, + ValuesOutputIteratorT d_values_out, + EqualityOp equality_op, + ScanOpT scan_op, + InitValueT init_value, + OffsetT num_items, + cudaStream_t stream) + { + using MaxPolicyT = typename DispatchScanByKey::MaxPolicy; + + cudaError_t error; + + do { - CUB_DETAIL_RUNTIME_DEBUG_SYNC_USAGE_LOG - - return Dispatch(d_temp_storage, - temp_storage_bytes, - d_keys_in, - d_values_in, - d_values_out, - equality_op, - scan_op, - init_value, - num_items, - stream); - } -}; + // Get PTX version + int ptx_version = 0; + if (CubDebug(error = PtxVersion(ptx_version))) + { + break; + } + + // Create dispatch functor + DispatchScanByKey dispatch(d_temp_storage, + temp_storage_bytes, + d_keys_in, + d_values_in, + d_values_out, + equality_op, + scan_op, + init_value, + num_items, + stream, + ptx_version); + + // Dispatch to chained policy + if (CubDebug(error = MaxPolicyT::Invoke(ptx_version, dispatch))) + { + break; + } + } while (0); + + return error; + } + + CUB_DETAIL_RUNTIME_DEBUG_SYNC_IS_NOT_SUPPORTED + CUB_RUNTIME_FUNCTION __forceinline__ static cudaError_t + Dispatch(void *d_temp_storage, + size_t &temp_storage_bytes, + KeysInputIteratorT d_keys_in, + ValuesInputIteratorT d_values_in, + ValuesOutputIteratorT d_values_out, + EqualityOp equality_op, + ScanOpT scan_op, + InitValueT init_value, + OffsetT num_items, + cudaStream_t stream, + bool debug_synchronous) + { + CUB_DETAIL_RUNTIME_DEBUG_SYNC_USAGE_LOG + return Dispatch(d_temp_storage, + temp_storage_bytes, + d_keys_in, + d_values_in, + d_values_out, + equality_op, + scan_op, + init_value, + num_items, + stream); + } +}; CUB_NAMESPACE_END diff --git a/cub/thread/thread_store.cuh b/cub/thread/thread_store.cuh index 3af2613d4b..7f711bf061 100644 --- a/cub/thread/thread_store.cuh +++ b/cub/thread/thread_store.cuh @@ -33,9 +33,9 @@ #pragma once -#include "../config.cuh" -#include "../util_ptx.cuh" -#include "../util_type.cuh" +#include +#include +#include CUB_NAMESPACE_BEGIN diff --git a/cub/util_type.cuh b/cub/util_type.cuh index 52cf6c0ed1..0910f11c63 100644 --- a/cub/util_type.cuh +++ b/cub/util_type.cuh @@ -47,12 +47,12 @@ #include #endif +#include #include #include #include #include - CUB_NAMESPACE_BEGIN @@ -312,7 +312,8 @@ struct InputValue if (m_is_future) { m_future_value = other.m_future_value; } else { - m_immediate_value = other.m_immediate_value; + detail::uninitialized_copy(&m_immediate_value, + other.m_immediate_value); } } diff --git a/test/test_device_reduce_by_key.cu b/test/test_device_reduce_by_key.cu index 4e216ff4e9..98a460d669 100644 --- a/test/test_device_reduce_by_key.cu +++ b/test/test_device_reduce_by_key.cu @@ -320,9 +320,11 @@ int Solve( ReductionOpT reduction_op, 
int num_items) { + using AccumT = cub::detail::accumulator_t; + // First item KeyT previous = h_keys_in[0]; - ValueT aggregate = h_values_in[0]; + AccumT aggregate = h_values_in[0]; int num_segments = 0; // Subsequent items @@ -331,7 +333,7 @@ int Solve( if (!equality_op(previous, h_keys_in[i])) { h_keys_reference[num_segments] = previous; - h_values_reference[num_segments] = aggregate; + h_values_reference[num_segments] = static_cast(aggregate); num_segments++; aggregate = h_values_in[i]; } @@ -343,7 +345,7 @@ int Solve( } h_keys_reference[num_segments] = previous; - h_values_reference[num_segments] = aggregate; + h_values_reference[num_segments] = static_cast(aggregate); num_segments++; return num_segments; diff --git a/test/test_device_scan.cu b/test/test_device_scan.cu index 807bb98a22..6197bda2db 100644 --- a/test/test_device_scan.cu +++ b/test/test_device_scan.cu @@ -76,8 +76,9 @@ struct WrapperFunctor WrapperFunctor(OpT op) : op(op) {} - template - __host__ __device__ __forceinline__ T operator()(const T &a, const T &b) const + template + __host__ __device__ __forceinline__ auto operator()(const T &a, const U &b) const + -> decltype(op(a, b)) { return static_cast(op(a, b)); } @@ -511,8 +512,11 @@ void Solve( ScanOpT scan_op, InitialValueT initial_value) { - // Use the initial value type for accumulation per P0571 - using AccumT = InitialValueT; + using AccumT = + cub::detail::accumulator_t< + ScanOpT, + InitialValueT, + cub::detail::value_t>; if (num_items > 0) { @@ -544,9 +548,11 @@ void Solve( ScanOpT scan_op, NullType) { - // When no initial value type is supplied, use InputT for accumulation - // per P0571 - using AccumT = cub::detail::value_t; + using AccumT = + cub::detail::accumulator_t< + ScanOpT, + cub::detail::value_t, + cub::detail::value_t>; if (num_items > 0) { @@ -1013,6 +1019,162 @@ void TestSize( } } +class CustomInputT +{ + char m_val{}; + +public: + __host__ __device__ explicit CustomInputT(char val) + : m_val(val) + {} + + __host__ __device__ int get() const { return static_cast(m_val); } +}; + +class CustomAccumulatorT +{ + int m_val{0}; + int m_magic_value{42}; + + __host__ __device__ CustomAccumulatorT(int val) + : m_val(val) + {} + +public: + __host__ __device__ CustomAccumulatorT() + {} + + __host__ __device__ CustomAccumulatorT(const CustomAccumulatorT &in) + : m_val(in.is_valid() * in.get()) + , m_magic_value(in.is_valid() * 42) + {} + + __host__ __device__ CustomAccumulatorT(const CustomInputT &in) + : m_val(in.get()) + , m_magic_value(42) + {} + + __host__ __device__ void operator=(const CustomInputT &in) + { + if (this->is_valid()) + { + m_val = in.get(); + } + } + + __host__ __device__ void operator=(const CustomAccumulatorT &in) + { + if (this->is_valid() && in.is_valid()) + { + m_val = in.get(); + } + } + + __host__ __device__ CustomAccumulatorT + operator+(const CustomInputT &in) const + { + const int multiplier = this->is_valid(); + return {(m_val + in.get()) * multiplier}; + } + + __host__ __device__ CustomAccumulatorT + operator+(const CustomAccumulatorT &in) const + { + const int multiplier = this->is_valid() && in.is_valid(); + return {(m_val + in.get()) * multiplier}; + } + + __host__ __device__ int get() const { return m_val; } + + __host__ __device__ bool is_valid() const { return m_magic_value == 42; } +}; + +class CustomOutputT +{ + int *m_d_ok_count{}; + int m_expected{}; + +public: + __host__ __device__ CustomOutputT(int *d_ok_count, int expected) + : m_d_ok_count(d_ok_count) + , m_expected(expected) + {} + + __device__ void 
operator=(const CustomAccumulatorT &accum) const + { + const int ok = accum.is_valid() && (accum.get() == m_expected); + atomicAdd(m_d_ok_count, ok); + } +}; + +__global__ void InitializeTestAccumulatorTypes(int num_items, + int *d_ok_count, + CustomInputT *d_in, + CustomOutputT *d_out) +{ + const int idx = static_cast(threadIdx.x + blockIdx.x * blockDim.x); + + if (idx < num_items) + { + d_in[idx] = CustomInputT(1); + d_out[idx] = CustomOutputT{d_ok_count, idx}; + } +} + +void TestAccumulatorTypes() +{ + const int num_items = 2 * 1024 * 1024; + const int block_size = 256; + const int grid_size = (num_items + block_size - 1) / block_size; + + CustomInputT *d_in{}; + CustomOutputT *d_out{}; + CustomAccumulatorT init{}; + int *d_ok_count{}; + + CubDebugExit(g_allocator.DeviceAllocate((void **)&d_ok_count, sizeof(int))); + CubDebugExit(g_allocator.DeviceAllocate((void **)&d_out, + sizeof(CustomOutputT) * num_items)); + CubDebugExit(g_allocator.DeviceAllocate((void **)&d_in, + sizeof(CustomInputT) * num_items)); + + InitializeTestAccumulatorTypes<<>>(num_items, + d_ok_count, + d_in, + d_out); + + std::uint8_t *d_temp_storage{}; + std::size_t temp_storage_bytes{}; + + CubDebugExit(cub::DeviceScan::ExclusiveScan(d_temp_storage, + temp_storage_bytes, + d_in, + d_out, + cub::Sum{}, + init, + num_items)); + + CubDebugExit( + g_allocator.DeviceAllocate((void **)&d_temp_storage, temp_storage_bytes)); + CubDebugExit(cudaMemset(d_temp_storage, 1, temp_storage_bytes)); + + CubDebugExit(cub::DeviceScan::ExclusiveScan(d_temp_storage, + temp_storage_bytes, + d_in, + d_out, + cub::Sum{}, + init, + num_items)); + + int ok{}; + CubDebugExit(cudaMemcpy(&ok, d_ok_count, sizeof(int), cudaMemcpyDeviceToHost)); + + AssertEquals(ok, num_items); + + CubDebugExit(g_allocator.DeviceFree(d_out)); + CubDebugExit(g_allocator.DeviceFree(d_in)); + CubDebugExit(g_allocator.DeviceFree(d_ok_count)); +} //--------------------------------------------------------------------- @@ -1102,6 +1264,7 @@ int main(int argc, char** argv) TestSize(num_items, TestBar(0, 0), TestBar(1ll << 63, 1 << 31)); + TestAccumulatorTypes(); #endif return 0; diff --git a/test/test_device_scan_by_key.cu b/test/test_device_scan_by_key.cu index f0d503fc10..f3bf841763 100644 --- a/test/test_device_scan_by_key.cu +++ b/test/test_device_scan_by_key.cu @@ -84,8 +84,9 @@ struct WrapperFunctor WrapperFunctor(OpT op) : op(op) {} - template - __host__ __device__ __forceinline__ T operator()(const T &a, const T &b) const + template + __host__ __device__ __forceinline__ auto operator()(const T &a, const U &b) const + -> decltype(op(a, b)) { return static_cast(op(a, b)); } @@ -412,8 +413,8 @@ void Solve( InitialValueT initial_value, EqualityOpT equality_op) { - // Use the initial value type for accumulation per P0571 - using AccumT = InitialValueT; + using ValueT = cub::detail::value_t; + using AccumT = cub::detail::accumulator_t; if (num_items > 0) { @@ -453,9 +454,8 @@ void Solve( NullType /*initial_value*/, EqualityOpT equality_op) { - // When no initial value type is supplied, use InputT for accumulation - // per P0571 - using AccumT = cub::detail::value_t; + using ValueT = cub::detail::value_t; + using AccumT = cub::detail::accumulator_t; if (num_items > 0) {
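// For illustration only: with these changes the reference Solve() helpers and
// the new TestAccumulatorTypes coverage exercise an accumulator type that can
// differ from both the input and output types. The host-side model below
// (hypothetical names, not test code from this patch) shows the shape of such
// a scan -- narrow inputs, a wider accumulator, and a distinct output type.

#include <cassert>
#include <cstddef>
#include <vector>

// Inputs are narrow (signed char), the running total is kept in long long,
// and results are written to int -- three distinct types, as in the tests.
inline void ExclusiveScanModel(const std::vector<signed char> &in,
                               std::vector<int> &out,
                               long long init_value)
{
  long long accumulator = init_value;
  out.resize(in.size());

  for (std::size_t i = 0; i < in.size(); ++i)
  {
    out[i] = static_cast<int>(accumulator); // exclusive: write before adding
    accumulator += in[i];
  }
}

inline void ExclusiveScanModelSelfTest()
{
  const std::vector<signed char> in(4, 1);
  std::vector<int> out;
  ExclusiveScanModel(in, out, 0);
  assert(out[0] == 0 && out[1] == 1 && out[2] == 2 && out[3] == 3);
}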