@@ -148,43 +148,38 @@ struct Identity<T, ComplexSum> {
 template <typename T, int BLOCK_THREADS, int ITEMS_PER_THREAD, typename Op>
 __global__ void BlockScanKernel(T* d_out,
                                 const T* d_in,
+                                T* d_agg,
                                 int64_t grid_size,
                                 int64_t scan_size,
                                 bool exclusive,
                                 Op op) {
   using MT = typename phi::dtype::MPTypeTrait<T>::Type;
 
-  // Specialize BlockLoad, BlockStore, and BlockRadixSort collective types
-  typedef cub::
-      BlockLoad<MT, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_TRANSPOSE>
-          BlockLoadT;
-  typedef cub::BlockStore<MT,
-                          BLOCK_THREADS,
-                          ITEMS_PER_THREAD,
-                          cub::BLOCK_STORE_TRANSPOSE>
-      BlockStoreT;
-  typedef cub::BlockScan<MT, BLOCK_THREADS> BlockScanT;
+  // Specialize BlockLoad, BlockStore, BlockScan, and BlockReduce collective types
+  using BlockLoadT = cub::BlockLoad<MT, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
+  using BlockStoreT = cub::BlockStore<MT, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE>;
+  using BlockScanT = cub::BlockScan<MT, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS>;
+  using BlockReduceT = cub::BlockReduce<MT, BLOCK_THREADS>;
+
   // Allocate type-safe, repurposable shared memory for collectives
   __shared__ union {
     typename BlockLoadT::TempStorage load;
     typename BlockStoreT::TempStorage store;
     typename BlockScanT::TempStorage scan;
+    typename BlockReduceT::TempStorage reduce;
   } temp_storage;
 
   // Obtain this block's segment of consecutive keys (blocked across threads)
   int64_t item_per_block = BLOCK_THREADS * ITEMS_PER_THREAD;
 
   for (int64_t bx = blockIdx.x; bx < grid_size; bx += gridDim.x) {
+    int64_t row_offset = bx * scan_size;
     BlockPrefixCallbackOp<MT, Op> prefix_op(Identity<MT, Op>::value, op);
 
-    for (int64_t block_offset = 0; block_offset < scan_size;
-         block_offset += item_per_block) {
+    for (int64_t block_offset = 0; block_offset < scan_size; block_offset += item_per_block) {
       int64_t valid_item = (scan_size - block_offset > item_per_block)
                                ? item_per_block
                                : (scan_size - block_offset);
-      if (scan_size < item_per_block) {
-        valid_item = scan_size;
-      }
 
-      int64_t offset = bx * scan_size + block_offset;
+      int64_t offset = row_offset + block_offset;
 
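The tile loop above follows CUB's standard multi-tile prefix-scan idiom: each tile of the row is loaded, scanned with the running total carried in by a prefix callback, and stored, and the callback is updated so the next tile starts from the right value. Below is a minimal self-contained sketch of that idiom, assuming float data and an inclusive sum; the names RunningPrefix and RowInclusiveSum are illustrative, not Paddle's (Paddle's callback is BlockPrefixCallbackOp<MT, Op>), and the new BlockReduceT / d_agg path from the diff is not reproduced here.

#include <cub/cub.cuh>

// Running-total functor: CUB calls it with each tile's aggregate and uses the
// returned value as that tile's prefix (illustrative stand-in for
// BlockPrefixCallbackOp<MT, Op>).
struct RunningPrefix {
  float total;
  __device__ explicit RunningPrefix(float init) : total(init) {}
  __device__ float operator()(float tile_aggregate) {
    float prefix = total;     // seed for the current tile
    total += tile_aggregate;  // carry into the next tile
    return prefix;
  }
};

// One block scans one row at a time; the grid-stride loop over rows mirrors
// the `bx` loop in the kernel above.
template <int BLOCK_THREADS, int ITEMS_PER_THREAD>
__global__ void RowInclusiveSum(float* d_out, const float* d_in,
                                int64_t num_rows, int64_t row_len) {
  using BlockLoadT = cub::BlockLoad<float, BLOCK_THREADS, ITEMS_PER_THREAD,
                                    cub::BLOCK_LOAD_WARP_TRANSPOSE>;
  using BlockStoreT = cub::BlockStore<float, BLOCK_THREADS, ITEMS_PER_THREAD,
                                      cub::BLOCK_STORE_WARP_TRANSPOSE>;
  using BlockScanT =
      cub::BlockScan<float, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS>;

  __shared__ union {
    typename BlockLoadT::TempStorage load;
    typename BlockStoreT::TempStorage store;
    typename BlockScanT::TempStorage scan;
  } temp_storage;

  constexpr int kTile = BLOCK_THREADS * ITEMS_PER_THREAD;

  for (int64_t row = blockIdx.x; row < num_rows; row += gridDim.x) {
    const int64_t row_base = row * row_len;
    RunningPrefix prefix(0.0f);  // reset the carry for every row

    for (int64_t tile = 0; tile < row_len; tile += kTile) {
      // Number of in-bounds elements in this tile.
      int valid = (row_len - tile < kTile) ? static_cast<int>(row_len - tile)
                                           : kTile;
      float items[ITEMS_PER_THREAD];

      // Load, scan with the running prefix, then store; the syncs are needed
      // because load, scan, and store alias the same shared-memory union.
      BlockLoadT(temp_storage.load)
          .Load(d_in + row_base + tile, items, valid, 0.0f);
      __syncthreads();
      BlockScanT(temp_storage.scan).InclusiveSum(items, items, prefix);
      __syncthreads();
      BlockStoreT(temp_storage.store)
          .Store(d_out + row_base + tile, items, valid);
      __syncthreads();
    }
  }
}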
@@ -376,7 +371,6 @@ void ScanKernel(const Context& dev_ctx,
   if (!transpose && !reverse) {
     BlockScanKernel<T, 128, 4, Op><<<scan_grid, 128, 0, dev_ctx.stream()>>>(
         out_data, in_data, grid_size, scan_size, exclusive, op);
-
   } else {
     BlockScanKernel<T, 128, 4, Op><<<scan_grid, 128, 0, dev_ctx.stream()>>>(
         next_out_data, next_in_data, grid_size, scan_size, exclusive, op);
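For context on the launches in the second hunk: grid_size is the number of independent rows and scan_size the length of the scanned axis in a row-major [grid_size, scan_size] buffer. A hedged host-side sketch of such a launch, reusing the RowInclusiveSum sketch above; cumsum_last_axis and the grid cap are assumptions, not Paddle's ScanKernel.

#include <algorithm>
#include <cstdint>

// Assumed wrapper, not a Paddle API: inclusive cumsum along the last axis of
// an [outer, axis_len] row-major device buffer.
void cumsum_last_axis(float* d_out, const float* d_in, int64_t outer,
                      int64_t axis_len, cudaStream_t stream) {
  constexpr int kThreads = 128;  // BLOCK_THREADS, matching the launches above
  constexpr int kItems = 4;      // ITEMS_PER_THREAD
  // Cap the grid (arbitrary cap); the kernel's grid-stride loop over rows
  // picks up any remaining rows.
  int blocks = static_cast<int>(std::min<int64_t>(outer, 65535));
  RowInclusiveSum<kThreads, kItems>
      <<<blocks, kThreads, 0, stream>>>(d_out, d_in, outer, axis_len);
}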