Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions paddle/phi/kernels/cpu/cross_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,13 @@ void CrossGradKernel(const Context &dev_ctx,
"But received: Input(X/Y).dims() == [%s].",
input_x_dims));
}
auto outer_loops = 1;
for (auto i = 0; i < dim; i++) {
outer_loops *= static_cast<int>(input_x_dims[i]);
int64_t outer_loops = 1;
for (int i = 0; i < dim; i++) {
outer_loops *= input_x_dims[i];
}
auto slice_size = 1;
for (auto i = dim + 1; i < input_x_dims.size(); i++) {
slice_size *= static_cast<int>(input_x_dims[i]);
int64_t slice_size = 1;
for (int i = dim + 1; i < input_x_dims.size(); i++) {
slice_size *= input_x_dims[i];
}

int64_t numel = x.numel();
Expand Down Expand Up @@ -111,12 +111,12 @@ void CrossGradKernel(const Context &dev_ctx,
dev_ctx.template Alloc<T>(output_x_grad);
dev_ctx.template Alloc<T>(output_y_grad);

for (auto i = 0; i < outer_loops; i++) {
for (auto j = 0; j < 3; j++) {
auto dst_pos = (3 * i + j) * slice_size;
auto in_pos1 = (3 * i + ((j + 1) % 3)) * slice_size;
auto in_pos2 = (3 * i + ((j + 2) % 3)) * slice_size;
for (auto k = 0; k < slice_size; k++) {
for (int64_t i = 0; i < outer_loops; i++) {
for (int64_t j = 0; j < 3; j++) {
int64_t dst_pos = (3 * i + j) * slice_size;
int64_t in_pos1 = (3 * i + ((j + 1) % 3)) * slice_size;
int64_t in_pos2 = (3 * i + ((j + 2) % 3)) * slice_size;
for (int64_t k = 0; k < slice_size; k++) {
out_dx_vec[dst_pos + k] =
input_dout_vec[in_pos2 + k] * input_y_vec[in_pos1 + k] -
input_dout_vec[in_pos1 + k] * input_y_vec[in_pos2 + k];
Expand Down
20 changes: 10 additions & 10 deletions paddle/phi/kernels/cpu/cross_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,13 @@ void CrossKernel(const Context& dev_ctx,
dev_ctx.template Alloc<T>(output);
return;
}
auto outer_loops = 1;
int64_t outer_loops = 1;
for (auto i = 0; i < dim; i++) {
outer_loops *= static_cast<int>(input_x_dims[i]);
outer_loops *= input_x_dims[i];
}
auto slice_size = 1;
int64_t slice_size = 1;
for (auto i = dim + 1; i < input_x_dims.size(); i++) {
slice_size *= static_cast<int>(input_x_dims[i]);
slice_size *= input_x_dims[i];
}

std::vector<T> input_x_vec, input_y_vec;
Expand All @@ -91,13 +91,13 @@ void CrossKernel(const Context& dev_ctx,

dev_ctx.template Alloc<T>(output);

for (auto i = 0; i < outer_loops; i++) {
for (auto j = 0; j < 3; j++) {
auto dst_pos = (3 * i + j) * slice_size;
auto in_pos1 = (3 * i + ((j + 1) % 3)) * slice_size;
auto in_pos2 = (3 * i + ((j + 2) % 3)) * slice_size;
for (int64_t i = 0; i < outer_loops; i++) {
for (int64_t j = 0; j < 3; j++) {
int64_t dst_pos = (3 * i + j) * slice_size;
int64_t in_pos1 = (3 * i + ((j + 1) % 3)) * slice_size;
int64_t in_pos2 = (3 * i + ((j + 2) % 3)) * slice_size;

for (auto k = 0; k < slice_size; k++) {
for (int64_t k = 0; k < slice_size; k++) {
out_vec[dst_pos + k] =
input_x_vec[in_pos1 + k] * input_y_vec[in_pos2 + k] -
input_x_vec[in_pos2 + k] * input_y_vec[in_pos1 + k];
Expand Down
10 changes: 6 additions & 4 deletions paddle/phi/kernels/gpu/class_center_sample_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@ namespace cub = hipcub;
#include "paddle/phi/core/kernel_registry.h"

namespace phi {
#define CUDA_KERNEL_LOOP(i, n) \
for (int32_t i = blockIdx.x * blockDim.x + threadIdx.x, \
step = blockDim.x * gridDim.x; \
i < (n); \
#define CUDA_KERNEL_LOOP_TYPE(i, n, index_type) \
for (index_type i = blockIdx.x * blockDim.x + threadIdx.x, \
step = blockDim.x * gridDim.x; \
i < (n); \
i += step)

#define CUDA_KERNEL_LOOP(i, n) CUDA_KERNEL_LOOP_TYPE(i, n, int32_t)

static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaximumNumBlocks = 4096;

Expand Down
44 changes: 22 additions & 22 deletions paddle/phi/kernels/gpu/cross_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,21 @@
namespace phi {

template <typename T>
__global__ void CrossGrad(const T* x,
const T* y,
const T* out,
T* out_dx,
T* out_dy,
const int64_t stride,
const int64_t N,
phi::funcs::IndexCalculator<int> index_calculator) {
CUDA_KERNEL_LOOP(i, N) {
__global__ void CrossGrad(
const T* x,
const T* y,
const T* out,
T* out_dx,
T* out_dy,
const int64_t stride,
const int64_t N,
phi::funcs::IndexCalculator<int64_t> index_calculator) {
CUDA_KERNEL_LOOP_TYPE(i, N, int64_t) {
int64_t offset = index_calculator(i);

auto pos0 = offset + 0 * stride;
auto pos1 = offset + 1 * stride;
auto pos2 = offset + 2 * stride;
int64_t pos0 = offset + 0 * stride;
int64_t pos1 = offset + 1 * stride;
int64_t pos2 = offset + 2 * stride;

using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

Expand Down Expand Up @@ -127,16 +128,16 @@ void CrossGradKernel(const Context& dev_ctx,
std::vector<int64_t> full_strides;
std::vector<int64_t> merged_dims;

for (int64_t i = 0; i < dim; i++) {
for (int i = 0; i < dim; i++) {
if (i == 0) {
merged_dims.push_back(input_x_dims[i]);
} else {
merged_dims[0] *= input_x_dims[i];
}
}
int64_t merge_axis = merged_dims.size();
int merge_axis = merged_dims.size();
merged_dims.push_back(input_x_dims[dim]);
for (int64_t i = dim + 1; i < input_x_dims.size(); i++) {
for (int i = dim + 1; i < input_x_dims.size(); i++) {
if (i == dim + 1) {
merged_dims.push_back(input_x_dims[i]);
} else {
Expand All @@ -145,7 +146,7 @@ void CrossGradKernel(const Context& dev_ctx,
}

int64_t full_dim = 1;
for (int64_t i = 0; i < merged_dims.size(); i++) {
for (int i = 0; i < merged_dims.size(); i++) {
full_strides.insert(full_strides.begin(), full_dim);
full_dim *= merged_dims[merged_dims.size() - i - 1];
if (i == merge_axis) {
Expand All @@ -154,7 +155,7 @@ void CrossGradKernel(const Context& dev_ctx,
cal_dims.push_back(i);
}
int64_t left_dim = 1;
for (int64_t i = merged_dims.size() - 1; i >= 0; i--) {
for (int i = merged_dims.size() - 1; i >= 0; i--) {
if (i == merge_axis) {
continue;
}
Expand All @@ -168,11 +169,11 @@ void CrossGradKernel(const Context& dev_ctx,
const auto* input_out_grad_data = input_out_grad.data<T>();
auto* output_x_grad_data = dev_ctx.template Alloc<T>(x_grad);
auto* output_y_grad_data = dev_ctx.template Alloc<T>(y_grad);
auto index_calculator = phi::funcs::IndexCalculator<int>(
merged_dims.size() - 1, cal_dims, left_strides, full_strides);

backends::gpu::GpuLaunchConfig config =
backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel / 3);
auto index_calculator = phi::funcs::IndexCalculator<int64_t>(
merged_dims.size() - 1, cal_dims, left_strides, full_strides);
if (IsComplexType(x.dtype())) {
DenseTensor x_conj, y_conj;
DenseTensorMeta meta_xy(x.dtype(), x.dims());
Expand All @@ -189,7 +190,6 @@ void CrossGradKernel(const Context& dev_ctx,
input_y_data, numel, input_y_conj_data);
for_range(functor_x);
for_range(functor_y);

CrossGrad<<<config.block_per_grid,
config.thread_per_block,
0,
Expand All @@ -199,7 +199,7 @@ void CrossGradKernel(const Context& dev_ctx,
output_x_grad_data,
output_y_grad_data,
full_strides[merge_axis],
numel / 3,
static_cast<int64_t>(numel / 3),
index_calculator);
} else {
CrossGrad<<<config.block_per_grid,
Expand All @@ -211,7 +211,7 @@ void CrossGradKernel(const Context& dev_ctx,
output_x_grad_data,
output_y_grad_data,
full_strides[merge_axis],
numel / 3,
static_cast<int64_t>(numel / 3),
index_calculator);
}
}
Expand Down
26 changes: 13 additions & 13 deletions paddle/phi/kernels/gpu/cross_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ __global__ void Cross(const T* x,
T* out,
const int64_t stride,
const int64_t N,
phi::funcs::IndexCalculator<int> index_calculator) {
CUDA_KERNEL_LOOP(i, N) {
phi::funcs::IndexCalculator<int64_t> index_calculator) {
CUDA_KERNEL_LOOP_TYPE(i, N, int64_t) {
int64_t offset = index_calculator(i);

auto pos0 = offset + 0 * stride;
auto pos1 = offset + 1 * stride;
auto pos2 = offset + 2 * stride;
int64_t pos0 = offset + 0 * stride;
int64_t pos1 = offset + 1 * stride;
int64_t pos2 = offset + 2 * stride;

using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

Expand Down Expand Up @@ -111,16 +111,16 @@ void CrossKernel(const Context& dev_ctx,
std::vector<int64_t> full_strides;
std::vector<int64_t> merged_dims;

for (int64_t i = 0; i < dim; i++) {
for (int i = 0; i < dim; i++) {
if (i == 0) {
merged_dims.push_back(input_x_dims[i]);
} else {
merged_dims[0] *= input_x_dims[i];
}
}
int64_t merge_axis = merged_dims.size();
int merge_axis = merged_dims.size();
merged_dims.push_back(input_x_dims[dim]);
for (int64_t i = dim + 1; i < input_x_dims.size(); i++) {
for (int i = dim + 1; i < input_x_dims.size(); i++) {
if (i == dim + 1) {
merged_dims.push_back(input_x_dims[i]);
} else {
Expand All @@ -129,7 +129,7 @@ void CrossKernel(const Context& dev_ctx,
}

int64_t full_dim = 1;
for (int64_t i = 0; i < merged_dims.size(); i++) {
for (int i = 0; i < merged_dims.size(); i++) {
full_strides.insert(full_strides.begin(), full_dim);
full_dim *= merged_dims[merged_dims.size() - i - 1];
if (i == merge_axis) {
Expand All @@ -138,7 +138,7 @@ void CrossKernel(const Context& dev_ctx,
cal_dims.push_back(i);
}
int64_t left_dim = 1;
for (int64_t i = merged_dims.size() - 1; i >= 0; i--) {
for (int i = merged_dims.size() - 1; i >= 0; i--) {
if (i == merge_axis) {
continue;
}
Expand All @@ -149,21 +149,21 @@ void CrossKernel(const Context& dev_ctx,
const auto* input_x_data = input_x.data<T>();
const auto* input_y_data = input_y.data<T>();
auto* out_data = dev_ctx.template Alloc<T>(out);
auto index_calculator = phi::funcs::IndexCalculator<int>(
merged_dims.size() - 1, cal_dims, left_strides, full_strides);

int64_t numel = x.numel();
backends::gpu::GpuLaunchConfig config =
backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel / 3);

auto index_calculator = phi::funcs::IndexCalculator<int64_t>(
merged_dims.size() - 1, cal_dims, left_strides, full_strides);
Cross<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(input_x_data,
input_y_data,
out_data,
full_strides[merge_axis],
numel / 3,
static_cast<int64_t>(numel / 3),
index_calculator);
}
} // namespace phi
Expand Down
Loading