Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions paddle/fluid/primitive/decomp_rule/decomp_vjp/details.h
Original file line number Diff line number Diff line change
Expand Up @@ -2982,17 +2982,17 @@ void logcumsumexp_grad(const Tensor& x,
const Tensor reshape_x = backend::reshape<T>(x_cast, out_grad_shape);

if (out_grad_dtype == DataType::FLOAT32) {
lowest =
backend::full_with_tensor<T>(out_grad_shape,
std::numeric_limits<float>::lowest(),
out_grad_dtype,
out_grad.place());
lowest = backend::full_with_tensor<T>(
out_grad_shape,
-std::numeric_limits<float>::infinity(),
out_grad_dtype,
out_grad.place());
} else if (out_grad_dtype == DataType::FLOAT64) {
lowest =
backend::full_with_tensor<T>(out_grad_shape,
std::numeric_limits<double>::lowest(),
out_grad_dtype,
out_grad.place());
lowest = backend::full_with_tensor<T>(
out_grad_shape,
-std::numeric_limits<float>::infinity(),
out_grad_dtype,
out_grad.place());
}
const Tensor zero = backend::full_with_tensor<T>(
out_grad_shape, 0.0, out_grad_dtype, out_grad.place());
Expand All @@ -3016,12 +3016,12 @@ void logcumsumexp_grad(const Tensor& x,
const Tensor reshape_x = reshape<T>(x_cast, out_grad_cast.shape());
if (out_grad_dtype == DataType::FLOAT32) {
lowest = full<T>(out_grad_cast.shape(),
std::numeric_limits<float>::lowest(),
-std::numeric_limits<float>::infinity(),
out_grad_dtype,
out_grad_cast.place());
} else if (out_grad_dtype == DataType::FLOAT64) {
lowest = full<T>(out_grad_cast.shape(),
std::numeric_limits<double>::lowest(),
-std::numeric_limits<float>::infinity(),
out_grad_dtype,
out_grad_cast.place());
}
Expand Down
33 changes: 16 additions & 17 deletions paddle/phi/kernels/cpu/cum_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,17 @@ void ScanKernel(const Context& dev_ctx,
bool reverse,
Reducer reducer,
DenseTensor* out) {
if (out && out->numel() == 0) {
dev_ctx.template Alloc<T>(out);
return;
}
dev_ctx.template Alloc<T>(out);
if (out && out->numel() == 0) return;

if (x.numel() == 1) {
auto raw_dims = out->dims();
phi::Copy<Context>(dev_ctx, x, dev_ctx.GetPlace(), false, out);
out->Resize(raw_dims);
return;
}
auto out_dims = out->dims();

auto out_dims = out->dims();
PADDLE_ENFORCE_EQ(
axis < out_dims.size() && axis >= (0 - out_dims.size()),
true,
Expand Down Expand Up @@ -162,16 +159,16 @@ struct LogSumExp {
const T& b) const {
auto mi = Eigen::internal::scalar_min_op<T>()(a, b);
auto ma = Eigen::internal::scalar_max_op<T>()(a, b);
if (ma == -Eigen::NumTraits<T>::infinity()) {
return ma;
}

auto sub = Eigen::internal::scalar_difference_op<T>();
auto add = Eigen::internal::scalar_sum_op<T>();
auto exp = Eigen::internal::scalar_exp_op<T>();
auto log1p = Eigen::internal::scalar_log1p_op<T>();
auto cmp_lt =
Eigen::internal::scalar_cmp_op<T, T, Eigen::internal::cmp_LT>();

auto logsumexp = add(log1p(exp(sub(mi, ma))), ma);
return cmp_lt(ma, Eigen::NumTraits<T>::lowest()) ? ma : logsumexp;
return add(log1p(exp(sub(mi, ma))), ma);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const T& a,
const T& b) const {
Expand All @@ -186,7 +183,7 @@ struct LogSumExp {

auto logsumexp = padd(plog1p(pexp(psub(mi, ma))), ma);
return pselect(
pcmp_lt(ma, pset1(Eigen::NumTraits<T>::lowest())), ma, logsumexp);
pcmp_lt(ma, pset1(-Eigen::NumTraits<T>::infinity())), ma, logsumexp);
}
};

Expand All @@ -205,7 +202,7 @@ struct LogSumExpReducer {
}

// Returns a constant of type T and reads no member state. Presumably this is
// the reducer's starting accumulator value (inferred from the name; confirm
// against the Eigen reducer interface).
// NOTE(review): two return statements follow one another; control can only
// reach the first, so as written the result is NumTraits<T>::lowest() and the
// -infinity line never executes. They look like the before/after forms of a
// single change shown together -- confirm which one is intended.
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
return Eigen::NumTraits<T>::lowest();
return -Eigen::NumTraits<T>::infinity();
}

template <typename Packet>
Expand All @@ -229,12 +226,11 @@ struct LogSumExpReducer {
auto max_reducer = Eigen::internal::MaxReducer<T, Eigen::PropagateNaN>();
auto sum_reducer = Eigen::internal::SumReducer<T>();
auto exp = Eigen::internal::scalar_exp_op<T>();
auto cmp_lt =
Eigen::internal::scalar_cmp_op<T, T, Eigen::internal::cmp_LT>();
auto log = Eigen::internal::scalar_log_op<T>();
auto add = Eigen::internal::scalar_sum_op<T>();

using Eigen::internal::pexp;
using Eigen::internal::pset1;
using Eigen::internal::psub;

// `ma = max(x1, ..., xn)`
Expand All @@ -244,10 +240,13 @@ struct LogSumExpReducer {
//
// `logsumexp(x1, ..., xn) = ma + log (exp(x1 - ma) + ... + exp(xn - ma))`
auto ma = max_reducer.finalizeBoth(saccum, vaccum);
auto logsumexp = add(log(sum_reducer.finalizeBoth(
exp(saccum - ma), pexp(psub(vaccum, pset1(ma))))),
ma);
return cmp_lt(ma, Eigen::NumTraits<T>::lowest()) ? initialize() : logsumexp;
if (ma == -Eigen::NumTraits<T>::infinity()) {
return ma;
}

auto sum_of_exps = sum_reducer.finalizeBoth(exp(saccum - ma),
pexp(psub(vaccum, pset1(ma))));
return add(log(sum_of_exps), ma);
}
};

Expand Down
125 changes: 27 additions & 98 deletions paddle/phi/kernels/gpu/cum_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,24 @@ template <typename T, typename Op>
// Stateful device-side functor that carries a value across successive calls:
// each call to operator() folds `block_aggregate` into the stored state and
// returns the value that was stored before the call.
//
// NOTE(review): this listing contains two constructors with the same (T, Op)
// signature and two different updates of running_total_ inside operator().
// They look like the before/after forms of a single change shown together --
// confirm which form is intended before relying on this text.
struct BlockPrefixCallbackOp {
// Running prefix
T running_total_;
// Correction term carried between calls by the compensated update below.
T compensation_;
// Binary operator used for every combine step.
Op op_;

// Seeds only running_total_; compensation_ is not initialised by this form.
__device__ BlockPrefixCallbackOp(T running_total, Op op)
: running_total_(running_total), op_(op) {}
// Seeds both running_total_ and compensation_ with the same `identity` value.
__device__ BlockPrefixCallbackOp(T identity, Op op)
: running_total_(identity), compensation_(identity), op_(op) {}

// Callback operator to be entered by the first warp of threads in the block.
// tid 0 is responsible for returning a value for seeding the block-wide scan.
__device__ T operator()(T block_aggregate) {
// Remember the state before this call; it is the return value.
T old_prefix = running_total_;
// Plain (uncompensated) update of the running value.
running_total_ = op_(old_prefix, block_aggregate);

// Kahan Summation
// NOTE(review): the four steps below follow the compensated-summation
// recurrence, but every combine goes through op_ and every "subtraction" is
// expressed as op_(a, -b). That equals Kahan summation when op_ is addition.
// For any other Op (this file also shows an Identity<T, LogAddExp>
// specialization) unary minus does not obviously invert op_, and
// compensation_ starts at the operator's identity rather than at zero --
// verify the behaviour for non-additive operators.
T y = op_(block_aggregate, static_cast<T>(-compensation_));
T t = op_(running_total_, y);
T y_high = op_(t, static_cast<T>(-running_total_));
compensation_ = op_(y_high, static_cast<T>(-y));
running_total_ = t;

return old_prefix;
}
};
Expand Down Expand Up @@ -138,7 +146,7 @@ struct Identity<T, cub::Sum> {

// Identity<T, Op>::value specialization for the LogAddExp operator. Elsewhere
// in this listing Identity<MT, Op>::value seeds BlockPrefixCallbackOp and is
// passed to BlockLoad::Load as the fill value for out-of-range items.
//
// NOTE(review): `value` is declared twice below (lowest() and -infinity());
// these look like the before/after forms of a single change shown together --
// confirm which one is intended. Assuming LogAddExp computes
// log(exp(a) + exp(b)), -infinity is its exact neutral element because
// exp(-infinity) == 0, whereas lowest() is only a very negative finite value.
// std::numeric_limits<T>::infinity() is meaningful only when
// std::numeric_limits<T>::has_infinity is true; presumably this is only
// instantiated for floating-point T -- verify.
template <typename T>
struct Identity<T, LogAddExp> {
static constexpr T value = std::numeric_limits<T>::lowest();
static constexpr T value = -std::numeric_limits<T>::infinity();
};

template <typename T>
Expand All @@ -154,17 +162,17 @@ __global__ void BlockScanKernel(T* d_out,
bool exclusive,
Op op) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
using CallbackOp = BlockPrefixCallbackOp<MT, Op>;

// Specialize BlockLoad, BlockStore, and BlockRadixSort collective types
typedef cub::
BlockLoad<MT, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_TRANSPOSE>
BlockLoadT;
typedef cub::BlockStore<MT,
BLOCK_THREADS,
ITEMS_PER_THREAD,
cub::BLOCK_STORE_TRANSPOSE>
BlockStoreT;
typedef cub::BlockScan<MT, BLOCK_THREADS> BlockScanT;
using BlockLoadT = cub::
BlockLoad<MT, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_TRANSPOSE>;
using BlockStoreT = cub::BlockStore<MT,
BLOCK_THREADS,
ITEMS_PER_THREAD,
cub::BLOCK_STORE_TRANSPOSE>;
using BlockScanT = cub::BlockScan<MT, BLOCK_THREADS>;

// Allocate type-safe, repurposable shared memory for collectives
__shared__ union {
typename BlockLoadT::TempStorage load;
Expand All @@ -174,26 +182,20 @@ __global__ void BlockScanKernel(T* d_out,

// Obtain this block's segment of consecutive keys (blocked across threads)
int64_t item_per_block = BLOCK_THREADS * ITEMS_PER_THREAD;

for (int64_t bx = blockIdx.x; bx < grid_size; bx += gridDim.x) {
BlockPrefixCallbackOp<MT, Op> prefix_op(Identity<MT, Op>::value, op);
CallbackOp prefix_op(Identity<MT, Op>::value, op);

for (int64_t block_offset = 0; block_offset < scan_size;
block_offset += item_per_block) {
int64_t valid_item = (scan_size - block_offset > item_per_block)
? item_per_block
: (scan_size - block_offset);
if (scan_size < item_per_block) {
valid_item = scan_size;
}

int64_t valid_item = std::min(item_per_block, scan_size - block_offset);
int64_t offset = bx * scan_size + block_offset;

MT thread_keys[ITEMS_PER_THREAD];
BlockLoadT(temp_storage.load)
.Load(d_in + offset, thread_keys, valid_item, 0);

.Load(
d_in + offset, thread_keys, valid_item, Identity<MT, Op>::value);
__syncthreads();

if (exclusive) {
BlockScanT(temp_storage.scan)
.ExclusiveScan(thread_keys, thread_keys, op, prefix_op);
Expand All @@ -209,63 +211,6 @@ __global__ void BlockScanKernel(T* d_out,
}
}

template <typename Context, typename T>
typename std::enable_if<!std::is_same<T, phi::dtype::float16>::value &&
!std::is_same<T, phi::dtype::bfloat16>::value>::type
ThrustCumsumKernel(const Context& dev_ctx,
const T* in_data,
T* out_data,
int64_t size,
bool reverse,
bool exclusive) {
#ifdef __HIPCC__
const auto& policy = thrust::hip::par.on(dev_ctx.stream());
#else
phi::memory_utils::ThrustAllocator<cudaStream_t> allocator(dev_ctx.GetPlace(),
dev_ctx.stream());
const auto& policy = thrust::cuda::par(allocator).on(dev_ctx.stream());
#endif
if (reverse) {
thrust::reverse_iterator<thrust::device_ptr<const T>> reversed_in(
thrust::device_pointer_cast(in_data) + size);
thrust::reverse_iterator<thrust::device_ptr<T>> reversed_out(
thrust::device_pointer_cast(out_data) + size);
if (exclusive) {
thrust::exclusive_scan(
policy, reversed_in, reversed_in + size, reversed_out);
} else {
thrust::inclusive_scan(
policy, reversed_in, reversed_in + size, reversed_out);
}
} else {
if (exclusive) {
thrust::exclusive_scan(policy, in_data, in_data + size, out_data);
} else {
thrust::inclusive_scan(policy, in_data, in_data + size, out_data);
}
}

return;
}

// No-op overload selected when T is phi::dtype::float16: the body is empty, so
// nothing is read from in_data and nothing is written to out_data.
// Presumably it exists so that ThrustCumsumKernel<Context, T> still names a
// valid function for float16 at call sites that guard the call with an
// ordinary runtime `if` on the element type -- confirm against the caller.
template <typename Context, typename T>
typename std::enable_if<std::is_same<T, phi::dtype::float16>::value>::type
ThrustCumsumKernel(const Context& dev_ctx,
const phi::dtype::float16* in_data,
phi::dtype::float16* out_data,
int64_t size,
bool reverse,
bool exclusive) {}

// No-op overload selected when T is phi::dtype::bfloat16: the body is empty,
// so nothing is read from in_data and nothing is written to out_data.
// Presumably it exists so that ThrustCumsumKernel<Context, T> still names a
// valid function for bfloat16 at call sites that guard the call with an
// ordinary runtime `if` on the element type -- confirm against the caller.
template <typename Context, typename T>
typename std::enable_if<std::is_same<T, phi::dtype::bfloat16>::value>::type
ThrustCumsumKernel(const Context& dev_ctx,
const phi::dtype::bfloat16* in_data,
phi::dtype::bfloat16* out_data,
int64_t size,
bool reverse,
bool exclusive) {}

template <typename T, typename Context, typename Op>
void ScanKernel(const Context& dev_ctx,
const DenseTensor& x,
Expand All @@ -275,11 +220,8 @@ void ScanKernel(const Context& dev_ctx,
bool reverse,
Op op,
DenseTensor* out) {
if (out && out->numel() == 0) {
dev_ctx.template Alloc<T>(out);
return;
}
T* out_data = dev_ctx.template Alloc<T>(out);
if (out && out->numel() == 0) return;

// For 0D Tensor
if (out->numel() == 1) {
Expand All @@ -290,8 +232,6 @@ void ScanKernel(const Context& dev_ctx,
}

auto out_dims = out->dims();
auto size = x.numel();

PADDLE_ENFORCE_EQ(
axis < out_dims.size() && axis >= (0 - out_dims.size()),
true,
Expand All @@ -307,22 +247,11 @@ void ScanKernel(const Context& dev_ctx,

const T* in_data = x.data<T>();

// Use thrust for parallel acceleration when the input size is equal to the
// length of the 'axis' dimension.
if (!std::is_same<T, phi::dtype::float16>::value &&
!std::is_same<T, phi::dtype::bfloat16>::value &&
std::is_same<Op, cub::Sum>::value && size == out_dims[axis]) {
ThrustCumsumKernel<Context, T>(
dev_ctx, in_data, out_data, size, reverse, exclusive);
return;
}

size_t height = 1;
size_t width = 1;
for (size_t i = 0; i <= axis; i++) {
height *= out_dims[i];
}

for (size_t i = axis + 1; i < out_dims.size(); i++) {
width *= out_dims[i];
}
Expand Down
Loading
Loading