@@ -162,13 +162,16 @@ struct BlockPrefixCallbackOp<T, LogAddExp> {
   LogAddExp op_;
 
   __device__ BlockPrefixCallbackOp(T identity, LogAddExp op)
-      : max_so_far_(identity), scaled_sum_(0.0), compensation_(0.0), op_(op) {}
+      : max_so_far_(identity),
+        scaled_sum_(static_cast<T>(0.0)),
+        compensation_(static_cast<T>(0.0)),
+        op_(op) {}
 
   __device__ T operator()(T block_aggregate) {
     if (scaled_sum_ == 0.0) {
       max_so_far_ = block_aggregate;
-      scaled_sum_ = 1.0;
-      compensation_ = 0.0;
+      scaled_sum_ = static_cast<T>(1.0);
+      compensation_ = static_cast<T>(0.0);
       return std::numeric_limits<T>::lowest();
     }
 
@@ -255,6 +258,74 @@ __global__ void BlockScanKernel(T* d_out,
   }
 }
 
+template <typename Context, typename T>
+void ThrustCumsumKernel(const Context& dev_ctx,
+                        const T* in_data,
+                        T* out_data,
+                        int64_t size,
+                        bool reverse,
+                        bool exclusive) {
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
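+  // MT is the higher-precision compute type for T (e.g. float when T is
+  // float16 or bfloat16), so the scan accumulates above the storage precision.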
+
+#ifdef __HIPCC__
+  const auto& policy = thrust::hip::par.on(dev_ctx.stream());
+#else
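+  // On the CUDA path, Thrust's scratch allocations are routed through phi's
+  // ThrustAllocator and the scan runs on the context's stream.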
+  phi::memory_utils::ThrustAllocator<cudaStream_t> allocator(dev_ctx.GetPlace(),
+                                                             dev_ctx.stream());
+  const auto& policy = thrust::cuda::par(allocator).on(dev_ctx.stream());
+#endif
+
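+  // If T is already the compute type, scan the input/output buffers directly;
+  // otherwise stage the data through MT temporaries (see the else branch below).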
+  if constexpr (std::is_same_v<T, MT>) {
+    if (reverse) {
+      thrust::reverse_iterator<thrust::device_ptr<const T>> reversed_in(
+          thrust::device_pointer_cast(in_data) + size);
+      thrust::reverse_iterator<thrust::device_ptr<T>> reversed_out(
+          thrust::device_pointer_cast(out_data) + size);
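+      // Scanning through reverse iterators yields the reversed cumsum without
+      // materializing a flipped copy of the input.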
+      if (exclusive) {
+        thrust::exclusive_scan(
+            policy, reversed_in, reversed_in + size, reversed_out);
+      } else {
+        thrust::inclusive_scan(
+            policy, reversed_in, reversed_in + size, reversed_out);
+      }
+    } else {
+      if (exclusive) {
+        thrust::exclusive_scan(policy, in_data, in_data + size, out_data);
+      } else {
+        thrust::inclusive_scan(policy, in_data, in_data + size, out_data);
+      }
+    }
+  } else {
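+    // T is a reduced-precision type: copy into MT buffers, scan in MT, then
+    // convert the result back to T on the final copy.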
+    thrust::device_vector<MT> tmp_in(size);
+    thrust::device_vector<MT> tmp_out(size);
+    thrust::copy(policy, in_data, in_data + size, tmp_in.begin());
+
+    auto tmp_in_begin = tmp_in.begin();
+    auto tmp_in_end = tmp_in.end();
+    auto tmp_out_begin = tmp_out.begin();
+
+    if (reverse) {
+      auto reversed_in = tmp_in.rbegin();
+      auto reversed_out = tmp_out.rbegin();
+      if (exclusive) {
+        thrust::exclusive_scan(
+            policy, reversed_in, reversed_in + size, reversed_out);
+      } else {
+        thrust::inclusive_scan(
+            policy, reversed_in, reversed_in + size, reversed_out);
+      }
+    } else {
+      if (exclusive) {
+        thrust::exclusive_scan(policy, tmp_in_begin, tmp_in_end, tmp_out_begin);
+      } else {
+        thrust::inclusive_scan(policy, tmp_in_begin, tmp_in_end, tmp_out_begin);
+      }
+    }
+
+    thrust::copy(policy, tmp_out.begin(), tmp_out.end(), out_data);
+  }
+}
+
 template <typename T, typename Context, typename Op>
 void ScanKernel(const Context& dev_ctx,
                 const DenseTensor& x,
@@ -295,6 +366,15 @@ void ScanKernel(const Context& dev_ctx,
 
   const T* in_data = x.data<T>();
 
+  // Use Thrust for parallel acceleration in the plain 1-D cumsum case: the op
+  // is cub::Sum and the input size equals the length of the 'axis' dimension
+  // (i.e., all other dimensions are 1).
+  int64_t size = x.numel();
+  if (std::is_same_v<Op, cub::Sum> && size == out_dims[axis]) {
+    ThrustCumsumKernel<Context, T>(
+        dev_ctx, in_data, out_data, size, reverse, exclusive);
+    return;
+  }
+
   size_t height = 1;
   size_t width = 1;
   for (size_t i = 0; i <= axis; i++) {