86 changes: 83 additions & 3 deletions paddle/phi/kernels/gpu/cum_kernel.cu
@@ -162,13 +162,16 @@ struct BlockPrefixCallbackOp<T, LogAddExp> {
   LogAddExp op_;

   __device__ BlockPrefixCallbackOp(T identity, LogAddExp op)
-      : max_so_far_(identity), scaled_sum_(0.0), compensation_(0.0), op_(op) {}
+      : max_so_far_(identity),
+        scaled_sum_(static_cast<T>(0.0)),
+        compensation_(static_cast<T>(0.0)),
+        op_(op) {}

   __device__ T operator()(T block_aggregate) {
     if (scaled_sum_ == 0.0) {
       max_so_far_ = block_aggregate;
-      scaled_sum_ = 1.0;
-      compensation_ = 0.0;
+      scaled_sum_ = static_cast<T>(1.0);
+      compensation_ = static_cast<T>(0.0);
       return std::numeric_limits<T>::lowest();
     }

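Note: the static_cast<T> changes above keep the callback's state initialization well-formed when T is a reduced-precision type rather than double. For readers following the numerics, the (max_so_far_, scaled_sum_) state appears to maintain a running log-sum-exp: the largest value seen so far plus the sum of exp(x - max). Below is a minimal host-only sketch of that recurrence (illustrative only, not Paddle code; the compensation_ term, which looks like compensated summation, is omitted for brevity):

// Running log-sum-exp accumulator, analogous to the (max_so_far_, scaled_sum_)
// state above. Illustrative sketch only; drops the compensation term.
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

struct RunningLogSumExp {
  double max_so_far = std::numeric_limits<double>::lowest();
  double scaled_sum = 0.0;  // sum of exp(x_i - max_so_far)

  void add(double x) {
    if (scaled_sum == 0.0) {  // first element, mirrors the reset branch above
      max_so_far = x;
      scaled_sum = 1.0;
      return;
    }
    if (x > max_so_far) {
      // New maximum: rescale the accumulated sum before folding x in.
      scaled_sum = scaled_sum * std::exp(max_so_far - x) + 1.0;
      max_so_far = x;
    } else {
      scaled_sum += std::exp(x - max_so_far);
    }
  }

  // log(exp(x_0) + ... + exp(x_n)), without overflow for large inputs.
  double value() const { return max_so_far + std::log(scaled_sum); }
};

int main() {
  RunningLogSumExp acc;
  for (double x : std::vector<double>{1.0, 100.0, 3.0}) acc.add(x);
  std::printf("%.6f\n", acc.value());  // ~100.000000, dominated by the max
}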
@@ -255,6 +258,74 @@ __global__ void BlockScanKernel(T* d_out,
   }
 }

+template <typename Context, typename T>
+void ThrustCumsumKernel(const Context& dev_ctx,
+                        const T* in_data,
+                        T* out_data,
+                        int64_t size,
+                        bool reverse,
+                        bool exclusive) {
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+
+#ifdef __HIPCC__
+  const auto& policy = thrust::hip::par.on(dev_ctx.stream());
+#else
+  phi::memory_utils::ThrustAllocator<cudaStream_t> allocator(dev_ctx.GetPlace(),
+                                                             dev_ctx.stream());
+  const auto& policy = thrust::cuda::par(allocator).on(dev_ctx.stream());
+#endif
+
+  if constexpr (std::is_same_v<T, MT>) {
+    if (reverse) {
+      thrust::reverse_iterator<thrust::device_ptr<const T>> reversed_in(
+          thrust::device_pointer_cast(in_data) + size);
+      thrust::reverse_iterator<thrust::device_ptr<T>> reversed_out(
+          thrust::device_pointer_cast(out_data) + size);
+      if (exclusive) {
+        thrust::exclusive_scan(
+            policy, reversed_in, reversed_in + size, reversed_out);
+      } else {
+        thrust::inclusive_scan(
+            policy, reversed_in, reversed_in + size, reversed_out);
+      }
+    } else {
+      if (exclusive) {
+        thrust::exclusive_scan(policy, in_data, in_data + size, out_data);
+      } else {
+        thrust::inclusive_scan(policy, in_data, in_data + size, out_data);
+      }
+    }
+  } else {
+    thrust::device_vector<MT> tmp_in(size);
+    thrust::device_vector<MT> tmp_out(size);
+    thrust::copy(policy, in_data, in_data + size, tmp_in.begin());
+
+    auto tmp_in_begin = tmp_in.begin();
+    auto tmp_in_end = tmp_in.end();
+    auto tmp_out_begin = tmp_out.begin();
+
+    if (reverse) {
+      auto reversed_in = tmp_in.rbegin();
+      auto reversed_out = tmp_out.rbegin();
+      if (exclusive) {
+        thrust::exclusive_scan(
+            policy, reversed_in, reversed_in + size, reversed_out);
+      } else {
+        thrust::inclusive_scan(
+            policy, reversed_in, reversed_in + size, reversed_out);
+      }
+    } else {
+      if (exclusive) {
+        thrust::exclusive_scan(policy, tmp_in_begin, tmp_in_end, tmp_out_begin);
+      } else {
+        thrust::inclusive_scan(policy, tmp_in_begin, tmp_in_end, tmp_out_begin);
+      }
+    }
+
+    thrust::copy(policy, tmp_out.begin(), tmp_out.end(), out_data);
+  }
+}
+
 template <typename T, typename Context, typename Op>
 void ScanKernel(const Context& dev_ctx,
                 const DenseTensor& x,
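Note: ThrustCumsumKernel handles reversed scans by walking the input and output through reverse iterators instead of materializing a flipped copy, and routes mixed-precision types through a temporary buffer of the compute type MT. The standalone sketch below (illustrative only, not part of this PR) shows the same reverse-iterator pattern with plain thrust::device_vector; the exclusive case differs only in calling thrust::exclusive_scan.

// Reverse (suffix) cumulative sum via reverse iterators, no explicit flip.
// Illustrative only; values and shapes are made up.
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/scan.h>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> h_in = {1.f, 2.f, 3.f, 4.f};
  thrust::device_vector<float> in(h_in.begin(), h_in.end());
  thrust::device_vector<float> out(in.size());

  // out[i] = in[i] + in[i+1] + ... + in[n-1]
  thrust::inclusive_scan(in.rbegin(), in.rend(), out.rbegin());

  thrust::host_vector<float> h_out = out;
  for (float v : h_out) std::printf("%g ", v);  // prints: 10 9 7 4
  std::printf("\n");
  return 0;
}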
@@ -295,6 +366,15 @@ void ScanKernel(const Context& dev_ctx,

   const T* in_data = x.data<T>();

+  // Use thrust for parallel acceleration when the input size is equal to the
+  // length of the 'axis' dimension (i.e., it's a 1D scan).
+  int64_t size = x.numel();
+  if (std::is_same_v<Op, cub::Sum> && size == out_dims[axis]) {
+    ThrustCumsumKernel<Context, T>(
+        dev_ctx, in_data, out_data, size, reverse, exclusive);
+    return;
+  }
+
   size_t height = 1;
   size_t width = 1;
   for (size_t i = 0; i <= axis; i++) {
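Note: the new fast path in ScanKernel fires only for cub::Sum (cumsum, not logcumsumexp) and only when x.numel() equals the extent of the scanned axis, which can hold only if every other dimension has extent 1, i.e. the scan is effectively one-dimensional. Everything else keeps the existing height/width tiling and BlockScanKernel path shown in the hunk above. A tiny sketch of the shape predicate (hypothetical helper, not Paddle code):

// numel == dims[axis] can only hold when all other extents are 1.
// Hypothetical illustration of the dispatch condition; not part of the PR.
#include <cstdint>
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

static bool UsesThrustFastPath(const std::vector<int64_t>& dims, size_t axis) {
  const int64_t numel = std::accumulate(
      dims.begin(), dims.end(), int64_t{1}, std::multiplies<int64_t>());
  return numel == dims[axis];
}

int main() {
  std::printf("%d\n", UsesThrustFastPath({1024}, 0));     // 1: plain 1-D cumsum
  std::printf("%d\n", UsesThrustFastPath({1, 1024}, 1));  // 1: leading dims are 1
  std::printf("%d\n", UsesThrustFastPath({8, 1024}, 1));  // 0: keeps BlockScanKernel
}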