Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Fix for overflow of valid_items in warp reductions
Browse files Browse the repository at this point in the history
  • Loading branch information
dumerrill committed Feb 15, 2018
1 parent 17fbfba commit fb33c9f
Showing 1 changed file with 10 additions and 14 deletions.
24 changes: 10 additions & 14 deletions cub/block/specializations/block_reduce_warp_reductions.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ struct BlockReduceWarpReductions

// Thread fields
_TempStorage &temp_storage;
unsigned int linear_tid;
unsigned int warp_id;
unsigned int lane_id;
int linear_tid;
int warp_id;
int lane_id;


/// Constructor
Expand Down Expand Up @@ -169,13 +169,11 @@ struct BlockReduceWarpReductions
T input, ///< [in] Calling thread's input partial reductions
int num_valid) ///< [in] Number of valid elements (may be less than BLOCK_THREADS)
{
cub::Sum reduction_op;
unsigned int warp_offset = warp_id * LOGICAL_WARP_SIZE;
unsigned int warp_num_valid = (FULL_TILE && EVEN_WARP_MULTIPLE) ?
cub::Sum reduction_op;
int warp_offset = (warp_id * LOGICAL_WARP_SIZE);
int warp_num_valid = ((FULL_TILE && EVEN_WARP_MULTIPLE) || (warp_offset + LOGICAL_WARP_SIZE <= num_valid)) ?
LOGICAL_WARP_SIZE :
(warp_offset < num_valid) ?
num_valid - warp_offset :
0;
num_valid - warp_offset;

// Warp reduction in every warp
T warp_aggregate = WarpReduce(temp_storage.warp_reduce[warp_id]).template Reduce<(FULL_TILE && EVEN_WARP_MULTIPLE)>(
Expand All @@ -197,12 +195,10 @@ struct BlockReduceWarpReductions
int num_valid, ///< [in] Number of valid elements (may be less than BLOCK_THREADS)
ReductionOp reduction_op) ///< [in] Binary reduction operator
{
unsigned int warp_offset = warp_id * LOGICAL_WARP_SIZE;
unsigned int warp_num_valid = (FULL_TILE && EVEN_WARP_MULTIPLE) ?
int warp_offset = warp_id * LOGICAL_WARP_SIZE;
int warp_num_valid = ((FULL_TILE && EVEN_WARP_MULTIPLE) || (warp_offset + LOGICAL_WARP_SIZE <= num_valid)) ?
LOGICAL_WARP_SIZE :
(warp_offset < static_cast<unsigned int>(num_valid)) ?
num_valid - warp_offset :
0;
num_valid - warp_offset;

// Warp reduction in every warp
T warp_aggregate = WarpReduce(temp_storage.warp_reduce[warp_id]).template Reduce<(FULL_TILE && EVEN_WARP_MULTIPLE)>(
Expand Down

0 comments on commit fb33c9f

Please sign in to comment.