Skip to content

Commit 300ff1c

Browse files
committed
WIP on (no branch): fd03c70 fix sgn 0size
2 parents fd03c70 + 3ab89e5 commit 300ff1c

File tree

1 file changed

+10
-16
lines changed

1 file changed

+10
-16
lines changed

paddle/phi/kernels/gpu/cum_kernel.cu

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -148,43 +148,38 @@ struct Identity<T, ComplexSum> {
148148
template <typename T, int BLOCK_THREADS, int ITEMS_PER_THREAD, typename Op>
149149
__global__ void BlockScanKernel(T* d_out,
150150
const T* d_in,
151+
T* d_agg,
151152
int64_t grid_size,
152153
int64_t scan_size,
153154
bool exclusive,
154155
Op op) {
155156
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
156157

157-
// Specialize BlockLoad, BlockStore, and BlockRadixSort collective types
158-
typedef cub::
159-
BlockLoad<MT, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_TRANSPOSE>
160-
BlockLoadT;
161-
typedef cub::BlockStore<MT,
162-
BLOCK_THREADS,
163-
ITEMS_PER_THREAD,
164-
cub::BLOCK_STORE_TRANSPOSE>
165-
BlockStoreT;
166-
typedef cub::BlockScan<MT, BLOCK_THREADS> BlockScanT;
158+
// Specialize the BlockLoad, BlockStore, BlockScan, and BlockReduce collective types
159+
using BlockLoadT = cub::BlockLoad<MT, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
160+
using BlockStoreT = cub::BlockStore<MT, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE>;
161+
using BlockScanT = cub::BlockScan<MT, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS>;
162+
using BlockReduceT = cub::BlockReduce<MT, BLOCK_THREADS>;
163+
167164
// Allocate type-safe, repurposable shared memory for collectives
168165
__shared__ union {
169166
typename BlockLoadT::TempStorage load;
170167
typename BlockStoreT::TempStorage store;
171168
typename BlockScanT::TempStorage scan;
169+
typename BlockReduceT::TempStorage reduce;
172170
} temp_storage;
173171

174172
// Each pass, a thread block handles one segment of BLOCK_THREADS * ITEMS_PER_THREAD consecutive items (blocked across threads)
175173
int64_t item_per_block = BLOCK_THREADS * ITEMS_PER_THREAD;
176174

177175
for (int64_t bx = blockIdx.x; bx < grid_size; bx += gridDim.x) {
176+
int64_t block_offset = bx * scan_size;
178177
BlockPrefixCallbackOp<MT, Op> prefix_op(Identity<MT, Op>::value, op);
179178

180-
for (int64_t block_offset = 0; block_offset < scan_size;
181-
block_offset += item_per_block) {
179+
for (int64_t offset = 0; offset < scan_size; offset += item_per_block) {
182180
int64_t valid_item = (scan_size - block_offset > item_per_block)
183181
? item_per_block
184182
: (scan_size - block_offset);
185-
if (scan_size < item_per_block) {
186-
valid_item = scan_size;
187-
}
188183

189184
int64_t offset = bx * scan_size + block_offset;
190185

@@ -376,7 +371,6 @@ void ScanKernel(const Context& dev_ctx,
376371
if (!transpose && !reverse) {
377372
BlockScanKernel<T, 128, 4, Op><<<scan_grid, 128, 0, dev_ctx.stream()>>>(
378373
out_data, in_data, grid_size, scan_size, exclusive, op);
379-
380374
} else {
381375
BlockScanKernel<T, 128, 4, Op><<<scan_grid, 128, 0, dev_ctx.stream()>>>(
382376
next_out_data, next_in_data, grid_size, scan_size, exclusive, op);

0 commit comments

Comments
 (0)