6 changes: 2 additions & 4 deletions paddle/phi/kernels/cpu/cum_kernel.cc
@@ -57,20 +57,18 @@ void ScanKernel(const Context& dev_ctx,
bool reverse,
Reducer reducer,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
if (out && out->numel() == 0) {
dev_ctx.template Alloc<T>(out);
return;
}
dev_ctx.template Alloc<T>(out);

if (x.numel() == 1) {
auto raw_dims = out->dims();
phi::Copy<Context>(dev_ctx, x, dev_ctx.GetPlace(), false, out);
out->Resize(raw_dims);
return;
}
auto out_dims = out->dims();

auto out_dims = out->dims();
PADDLE_ENFORCE_EQ(
axis < out_dims.size() && axis >= (0 - out_dims.size()),
true,
148 changes: 37 additions & 111 deletions paddle/phi/kernels/gpu/cum_kernel.cu
@@ -56,24 +56,6 @@ __global__ void MatrixRowReverse(const T* matrix_data,
}
}

template <typename T, typename Op>
struct BlockPrefixCallbackOp {
// Running prefix
T running_total_;
Op op_;

__device__ BlockPrefixCallbackOp(T running_total, Op op)
: running_total_(running_total), op_(op) {}

// Callback operator to be entered by the first warp of threads in the block.
// tid 0 is responsible for returning a value for seeding the block-wide scan.
__device__ T operator()(T block_aggregate) {
T old_prefix = running_total_;
running_total_ = op_(old_prefix, block_aggregate);
return old_prefix;
}
};

// No bank-conflict transpose
template <typename T, int TILE_DIM, int BLOCK_ROWS>
__global__ void MatrixTranspose(T* odata,
Expand Down Expand Up @@ -146,25 +128,47 @@ struct Identity<T, ComplexSum> {
static constexpr T value = {0, 0};
};

template <typename T, typename Op>
struct BlockPrefixCallbackOp {
// Running prefix
T running_total_;
Op op_;

__device__ BlockPrefixCallbackOp(T running_total, Op op)
: running_total_(running_total), op_(op) {}

// Callback operator to be entered by the first warp of threads in the block.
// tid 0 is responsible for returning a value for seeding the block-wide scan.
__device__ T operator()(T block_aggregate) {
T old_prefix = running_total_;
running_total_ = op_(old_prefix, block_aggregate);
return old_prefix;
}
};
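For readers unfamiliar with the CUB prefix-callback pattern, a minimal sketch of how this functor carries the running total across tiles is shown below. It is not part of this PR: the float element type, the single item per thread, the kBlockThreads parameter, and the zero starting prefix are illustrative assumptions.

// Hypothetical sketch, not part of this PR: single-block inclusive sum over
// `n` floats, one tile of kBlockThreads elements at a time. CUB calls
// prefix_op once per tile with that tile's aggregate; the value returned by
// thread 0 seeds the next block-wide scan.
#include <cub/cub.cuh>

template <int kBlockThreads>
__global__ void TiledInclusiveSumSketch(const float* in, float* out, int n) {
  using BlockScanT = cub::BlockScan<float, kBlockThreads>;
  __shared__ typename BlockScanT::TempStorage temp_storage;

  BlockPrefixCallbackOp<float, cub::Sum> prefix_op(0.0f, cub::Sum());

  for (int tile_base = 0; tile_base < n; tile_base += kBlockThreads) {
    int idx = tile_base + threadIdx.x;
    float item = (idx < n) ? in[idx] : 0.0f;  // pad with the additive identity
    BlockScanT(temp_storage).InclusiveScan(item, item, cub::Sum(), prefix_op);
    __syncthreads();  // temp_storage is reused by the next tile
    if (idx < n) out[idx] = item;
  }
}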

// TODO(cangtianhuang): accelerate kernel by using chunk and d_agg
template <typename T, int BLOCK_THREADS, int ITEMS_PER_THREAD, typename Op>
__global__ void BlockScanKernel(T* d_out,
const T* d_in,
int64_t grid_size,
int64_t scan_size,
bool exclusive,
Op op) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
// accumulating in float still results in precision loss, so promote to double here
using MT = std::conditional_t<
std::is_same_v<typename phi::dtype::MPTypeTrait<T>::Type, float>,
double,
typename phi::dtype::MPTypeTrait<T>::Type>;

// Specialize BlockLoad, BlockStore, and BlockScan collective types
typedef cub::
BlockLoad<MT, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_TRANSPOSE>
BlockLoadT;
typedef cub::BlockStore<MT,
BLOCK_THREADS,
ITEMS_PER_THREAD,
cub::BLOCK_STORE_TRANSPOSE>
BlockStoreT;
typedef cub::BlockScan<MT, BLOCK_THREADS> BlockScanT;
using BlockLoadT = cub::
BlockLoad<MT, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_TRANSPOSE>;
using BlockStoreT = cub::BlockStore<MT,
BLOCK_THREADS,
ITEMS_PER_THREAD,
cub::BLOCK_STORE_TRANSPOSE>;
using BlockScanT = cub::BlockScan<MT, BLOCK_THREADS>;

// Allocate type-safe, repurposable shared memory for collectives
__shared__ union {
typename BlockLoadT::TempStorage load;
@@ -174,26 +178,20 @@ __global__ void BlockScanKernel(T* d_out,

// Obtain this block's segment of consecutive keys (blocked across threads)
int64_t item_per_block = BLOCK_THREADS * ITEMS_PER_THREAD;

for (int64_t bx = blockIdx.x; bx < grid_size; bx += gridDim.x) {
BlockPrefixCallbackOp<MT, Op> prefix_op(Identity<MT, Op>::value, op);

for (int64_t block_offset = 0; block_offset < scan_size;
block_offset += item_per_block) {
int64_t valid_item = (scan_size - block_offset > item_per_block)
? item_per_block
: (scan_size - block_offset);
if (scan_size < item_per_block) {
valid_item = scan_size;
}

int64_t valid_item = std::min(item_per_block, scan_size - block_offset);
int64_t offset = bx * scan_size + block_offset;

MT thread_keys[ITEMS_PER_THREAD];
BlockLoadT(temp_storage.load)
.Load(d_in + offset, thread_keys, valid_item, 0);

.Load(
d_in + offset, thread_keys, valid_item, Identity<MT, Op>::value);
__syncthreads();

if (exclusive) {
BlockScanT(temp_storage.scan)
.ExclusiveScan(thread_keys, thread_keys, op, prefix_op);
@@ -209,63 +207,6 @@ }
}
}
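For context, a hypothetical driver for the kernel above is sketched below; the real launch lives in the unchanged portion of this file and may differ. kBlockThreads, kItemsPerThread, the 4096 grid cap, and the float/cub::Sum instantiation are assumptions.

// Hypothetical usage sketch, not part of this PR: row-wise inclusive cumsum
// over a row-major [grid_size x scan_size] float matrix.
#include <algorithm>
#include <cstdint>

void LaunchBlockScanSketch(const float* d_in, float* d_out,
                           int64_t grid_size, int64_t scan_size,
                           cudaStream_t stream) {
  constexpr int kBlockThreads = 128;
  constexpr int kItemsPerThread = 4;
  // One block per row, capped; the kernel strides over rows with
  // `bx += gridDim.x`, so any cap still covers every row.
  int blocks = static_cast<int>(std::min<int64_t>(grid_size, 4096));
  BlockScanKernel<float, kBlockThreads, kItemsPerThread, cub::Sum>
      <<<blocks, kBlockThreads, 0, stream>>>(
          d_out, d_in, grid_size, scan_size, /*exclusive=*/false, cub::Sum());
}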

template <typename Context, typename T>
typename std::enable_if<!std::is_same<T, phi::dtype::float16>::value &&
!std::is_same<T, phi::dtype::bfloat16>::value>::type
ThrustCumsumKernel(const Context& dev_ctx,
const T* in_data,
T* out_data,
int64_t size,
bool reverse,
bool exclusive) {
#ifdef __HIPCC__
const auto& policy = thrust::hip::par.on(dev_ctx.stream());
#else
phi::memory_utils::ThrustAllocator<cudaStream_t> allocator(dev_ctx.GetPlace(),
dev_ctx.stream());
const auto& policy = thrust::cuda::par(allocator).on(dev_ctx.stream());
#endif
if (reverse) {
thrust::reverse_iterator<thrust::device_ptr<const T>> reversed_in(
thrust::device_pointer_cast(in_data) + size);
thrust::reverse_iterator<thrust::device_ptr<T>> reversed_out(
thrust::device_pointer_cast(out_data) + size);
if (exclusive) {
thrust::exclusive_scan(
policy, reversed_in, reversed_in + size, reversed_out);
} else {
thrust::inclusive_scan(
policy, reversed_in, reversed_in + size, reversed_out);
}
} else {
if (exclusive) {
thrust::exclusive_scan(policy, in_data, in_data + size, out_data);
} else {
thrust::inclusive_scan(policy, in_data, in_data + size, out_data);
}
}

return;
}

template <typename Context, typename T>
typename std::enable_if<std::is_same<T, phi::dtype::float16>::value>::type
ThrustCumsumKernel(const Context& dev_ctx,
const phi::dtype::float16* in_data,
phi::dtype::float16* out_data,
int64_t size,
bool reverse,
bool exclusive) {}

template <typename Context, typename T>
typename std::enable_if<std::is_same<T, phi::dtype::bfloat16>::value>::type
ThrustCumsumKernel(const Context& dev_ctx,
const phi::dtype::bfloat16* in_data,
phi::dtype::bfloat16* out_data,
int64_t size,
bool reverse,
bool exclusive) {}

template <typename T, typename Context, typename Op>
void ScanKernel(const Context& dev_ctx,
const DenseTensor& x,
@@ -275,12 +216,10 @@ void ScanKernel(const Context& dev_ctx,
bool reverse,
Op op,
DenseTensor* out) {
T* out_data = dev_ctx.template Alloc<T>(out);
if (out && out->numel() == 0) {
dev_ctx.template Alloc<T>(out);
return;
}
T* out_data = dev_ctx.template Alloc<T>(out);

// For 0D Tensor
if (out->numel() == 1) {
auto raw_dims = out->dims();
@@ -290,8 +229,6 @@ }
}

auto out_dims = out->dims();
auto size = x.numel();

PADDLE_ENFORCE_EQ(
axis < out_dims.size() && axis >= (0 - out_dims.size()),
true,
@@ -307,22 +244,11 @@

const T* in_data = x.data<T>();

// Use thrust for parallel acceleration when the input size is equal to the
// length of the 'axis' dimension.
if (!std::is_same<T, phi::dtype::float16>::value &&
!std::is_same<T, phi::dtype::bfloat16>::value &&
std::is_same<Op, cub::Sum>::value && size == out_dims[axis]) {
ThrustCumsumKernel<Context, T>(
dev_ctx, in_data, out_data, size, reverse, exclusive);
return;
}

size_t height = 1;
size_t width = 1;
for (size_t i = 0; i <= axis; i++) {
height *= out_dims[i];
}

for (size_t i = axis + 1; i < out_dims.size(); i++) {
width *= out_dims[i];
}
3 changes: 1 addition & 2 deletions paddle/phi/kernels/xpu/cum_kernel.cc
@@ -28,11 +28,10 @@ void CumsumKernel(const Context& dev_ctx,
bool reverse,
DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
dev_ctx.template Alloc<T>(out);
if (out && out->numel() == 0) {
dev_ctx.template Alloc<T>(out);
return;
}
dev_ctx.template Alloc<T>(out);

if (x.numel() == 1) {
int r = xpu::copy<XPUType>(dev_ctx.x_context(),