Modified RandomKernel with Kernel Primitive API #39666

Merged 4 commits on Feb 22, 2022
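The change applies one pattern across gaussian_random_op.cu and uniform_random_inplace_op.cu: random-value functors that were previously driven by thrust::transform over a counting iterator are now launched through the new IndexKernel in index_impl.cu.h, built on the Kernel Primitive API. A minimal before/after sketch assembled from the hunks below (illustrative only, not a compilable excerpt of either file):

// Before: thrust iterates the indices [0, size) and applies the functor per index.
thrust::counting_iterator<int64_t> index_sequence_begin(0);
thrust::transform(index_sequence_begin, index_sequence_begin + size,
                  thrust::device_ptr<T>(data),
                  GaussianGenerator<T>(mean, std, seed));

// After: the same functor is handed to the vectorized IndexKernel, which
// picks a vector width and launch configuration via the Kernel Primitive API.
auto func = GaussianGenerator<T>(mean, std, seed);
IndexKernel<T, GaussianGenerator<T>>(dev_cxt, tensor, func);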
29 changes: 13 additions & 16 deletions paddle/fluid/operators/gaussian_random_op.cu
@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/distribution_helper.h"
#include "paddle/fluid/operators/fill_constant_op.h"
#include "paddle/fluid/operators/index_impl.cu.h"

DECLARE_bool(use_curand);

@@ -65,7 +66,6 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
}
T mean = static_cast<T>(context.Attr<float>("mean"));
T std = static_cast<T>(context.Attr<float>("std"));
thrust::counting_iterator<int64_t> index_sequence_begin(0);
auto shape = GetShape(context);
tensor->Resize(shape);

@@ -88,15 +88,13 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
} else {
auto seed_offset = gen_cuda->IncrementOffset(1);
int64_t gen_offset = size * seed_offset.second;
thrust::transform(
index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
GaussianGenerator<T>(mean, std, seed_offset.first, gen_offset));
auto func =
GaussianGenerator<T>(mean, std, seed_offset.first, gen_offset);
IndexKernel<T, GaussianGenerator<T>>(dev_cxt, tensor, func);
}
} else {
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
GaussianGenerator<T>(mean, std, seed));
auto func = GaussianGenerator<T>(mean, std, seed);
IndexKernel<T, GaussianGenerator<T>>(dev_cxt, tensor, func);
}
}
};
@@ -116,23 +114,22 @@ class GPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel<T> {
}
T mean = static_cast<T>(context.Attr<float>("mean"));
T std = static_cast<T>(context.Attr<float>("std"));
thrust::counting_iterator<int64_t> index_sequence_begin(0);
int64_t size = tensor->numel();

int device_id = context.GetPlace().GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
auto& dev_cxt =
context.template device_context<platform::CUDADeviceContext>();

if (gen_cuda->GetIsInitPy() && seed_flag) {
auto seed_offset = gen_cuda->IncrementOffset(1);
int64_t gen_offset = size * seed_offset.second;
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
GaussianGenerator<T>(mean, std, seed_offset.first,
seed_offset.second));
auto func = GaussianGenerator<T>(mean, std, seed_offset.first,
seed_offset.second);
IndexKernel<T, GaussianGenerator<T>>(dev_cxt, tensor, func);
} else {
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
GaussianGenerator<T>(mean, std, seed));
auto func = GaussianGenerator<T>(mean, std, seed);
IndexKernel<T, GaussianGenerator<T>>(dev_cxt, tensor, func);
}
}
};
97 changes: 97 additions & 0 deletions paddle/fluid/operators/index_impl.cu.h
@@ -0,0 +1,97 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/random.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/distribution_helper.h"
#include "paddle/fluid/operators/fill_constant_op.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/primitive/kernel_primitives.h"

namespace paddle {
namespace operators {

namespace kps = phi::kps;
template <typename T, typename Functor, int VecSize>
__global__ void VectorizedIndexKernel(T *out, int numel, int main_offset,
Functor func) {
int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
int args[VecSize];
T result[VecSize];
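// Main loop: each block fills BLOCK_NUM_X * VecSize consecutive output
// elements per iteration and grid-strides up to main_offset, so these
// writes need no boundary check.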
for (; data_offset < main_offset; data_offset += stride) {
kps::InitWithDataIndex<int, VecSize, 1, 1>(&args[0], data_offset);
kps::ElementwiseUnary<int, T, VecSize, 1, 1, Functor>(&result[0], &args[0],
func);
kps::WriteData<T, VecSize, 1, 1, false>(out + data_offset, &result[0],
BLOCK_NUM_X * VecSize);
}
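// Tail: any remaining numel - main_offset elements are written with a
// boundary-checked (guarded) WriteData.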
int num = numel - data_offset;
if (num > 0) {
kps::InitWithDataIndex<int, VecSize, 1, 1>(&args[0], data_offset);
kps::ElementwiseUnary<int, T, VecSize, 1, 1, Functor>(&result[0], &args[0],
func);
kps::WriteData<T, VecSize, 1, 1, true>(out + data_offset, &result[0], num);
}
}

template <typename T, typename Functor>
void IndexKernel(const KPDevice &dev_ctx, Tensor *out, Functor func) {
int numel = out->numel();
T *out_data = out->mutable_data<T>(dev_ctx.GetPlace());
if (numel <= 0) return;
int vec_size = paddle::platform::GetVectorizedSize((out->data<T>()));
#ifdef PADDLE_WITH_XPU_KP
int block = 64;
int grid = 8;
auto stream = dev_ctx.x_context()->xpu_stream;
#else
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);
int grid = config.block_per_grid.x;
int block = config.thread_per_block.x;
auto stream = dev_ctx.stream();
#endif
Contributor:

Would it be better to fold the XPU thread configuration into GetGpuLaunchConfig1D? Then the thread setup here would not need a branch, and the same goes for obtaining the stream. Worth considering as an optimization.

Contributor Author:

Sure, that will be added later. (A rough sketch of what such a unified helper might look like follows this file's diff.)


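// Round numel down to a multiple of vec_size * block; indices beyond
// main_offset form the tail handled inside VectorizedIndexKernel.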
int main_offset = (numel / (vec_size * block)) * vec_size * block;
switch (vec_size) {
case 4:
VectorizedIndexKernel<T, Functor, 4><<<grid, block, 0, stream>>>(
out_data, numel, main_offset, func);
break;
case 2:
VectorizedIndexKernel<T, Functor, 2><<<grid, block, 0, stream>>>(
out_data, numel, main_offset, func);
break;
case 1:
VectorizedIndexKernel<T, Functor, 1><<<grid, block, 0, stream>>>(
out_data, numel, main_offset, func);
break;
default: {
PADDLE_THROW(paddle::platform::errors::Unimplemented(
"Unsupported vectorized size: %d !", vec_size));
break;
}
}
}

} // namespace operators
} // namespace paddle
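Following up on the review thread above: a rough, hypothetical sketch of the unified launch-configuration helper the reviewer suggests. Nothing below exists in this PR or in Paddle; the helper name and signature are assumptions, while GetGpuLaunchConfig1D, x_context()->xpu_stream, and stream() are taken from the code above.

// Hypothetical sketch (not part of this PR): fold the CUDA / XPU-KP branch in
// IndexKernel into a single helper so callers need no #ifdef for grid, block,
// or stream selection.
template <typename Context, typename StreamT>
void GetUnifiedLaunchConfig1D(const Context &dev_ctx, int numel, int vec_size,
                              int *grid, int *block, StreamT *stream) {
#ifdef PADDLE_WITH_XPU_KP
  // XPU-KP currently uses a fixed 8 x 64 launch shape (see IndexKernel above).
  *grid = 8;
  *block = 64;
  *stream = dev_ctx.x_context()->xpu_stream;
#else
  auto config =
      phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);
  *grid = config.block_per_grid.x;
  *block = config.thread_per_block.x;
  *stream = dev_ctx.stream();
#endif
}

With such a helper, the launch setup in IndexKernel would reduce to a single call, e.g. GetUnifiedLaunchConfig1D(dev_ctx, numel, vec_size, &grid, &block, &stream).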
141 changes: 13 additions & 128 deletions paddle/fluid/operators/uniform_random_inplace_op.cu
@@ -12,148 +12,33 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/uniform_random_op.h"
#include "paddle/phi/kernels/full_kernel.h"

namespace paddle {
namespace operators {

template <typename T>
struct UniformGenerator {
T min_, max_;
unsigned int seed_;
T diag_val_;
unsigned int diag_num_;
unsigned int diag_step_;
__host__ __device__ UniformGenerator(T min, T max, int seed, int diag_num,
int diag_step, T diag_val)
: min_(min),
max_(max),
seed_(seed),
diag_num_(diag_num),
diag_step_(diag_step),
diag_val_(diag_val) {}

__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(min_, max_);
rng.discard(n);
T out = dist(rng);
unsigned int remainder = n % (diag_step_ + 1);
if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) {
out = diag_val_;
}
return out;
}
};

template <typename T>
struct UniformGeneratorOffset {
T min_, max_;
unsigned int seed_;
T diag_val_;
unsigned int diag_num_;
unsigned int diag_step_;
int offset_;
__host__ __device__ UniformGeneratorOffset(T min, T max, int seed,
int diag_num, int diag_step,
T diag_val, int offset)
: min_(min),
max_(max),
seed_(seed),
diag_num_(diag_num),
diag_step_(diag_step),
diag_val_(diag_val),
offset_(offset) {}

__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(min_, max_);
rng.discard(n + offset_);
T out = dist(rng);
unsigned int remainder = n % (diag_step_ + 1);
if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) {
out = diag_val_;
}
return out;
}
};

template <typename T>
__global__ void fill_value(int64_t size, T* data, float value) {
for (int idx = threadIdx.x; idx < size; idx += blockDim.x) {
data[idx] = static_cast<T>(value);
}
}

// It seems that Eigen::Tensor::random in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random as uniform_random_op.cu.
template <typename T>
class GPUUniformRandomInplaceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto out_var = ctx.OutputVar("Out");
auto* tensor = out_var->GetMutable<framework::LoDTensor>();
T* data = tensor->mutable_data<T>(ctx.GetPlace());
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
bool seed_flag = false;
if (seed == 0) {
std::random_device rd;
seed = rd();
seed_flag = true;
}

T min = static_cast<T>(ctx.Attr<float>("min"));
T max = static_cast<T>(ctx.Attr<float>("max"));
unsigned int diag_num =
static_cast<unsigned int>(ctx.Attr<int>("diag_num"));
unsigned int diag_step =
static_cast<unsigned int>(ctx.Attr<int>("diag_step"));
T diag_val = static_cast<T>(ctx.Attr<float>("diag_val"));
thrust::counting_iterator<int64_t> index_sequence_begin(0);
int64_t size = tensor->numel();
int device_id = ctx.GetPlace().GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
if (gen_cuda->GetIsInitPy() && seed_flag) {
auto seed_offset = gen_cuda->IncrementOffset(1);
int64_t gen_offset = size * seed_offset.second;
thrust::transform(
index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
UniformGeneratorOffset<T>(min, max, seed_offset.first, diag_num,
diag_step, diag_val, gen_offset));
} else {
thrust::transform(
index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
UniformGenerator<T>(min, max, seed, diag_num, diag_step, diag_val));
}
void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>("Out");
UniformRandom<T>(context, tensor);
}
};

template <typename T>
class GPUUniformRandomInplaceGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#ifdef __HIPCC__
const int64_t kMaxBlockDim = 256;
#else
const int64_t kMaxBlockDim = 512;
#endif
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* data = dx->mutable_data<T>(ctx.GetPlace());

auto size = dx->numel();
int64_t kBlockDim = std::min(size, kMaxBlockDim);
fill_value<T><<<1, kBlockDim, 0>>>(size, data, static_cast<float>(0));
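// The gradient kernel now zeroes dx by calling phi::FullKernel with value 0,
// replacing the hand-written fill_value launch removed above.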
auto dims = vectorize(dx->dims());
const auto& dev_cxt =
ctx.template device_context<platform::CUDADeviceContext>();
float value = static_cast<float>(0.0f);
phi::FullKernel<T>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
paddle::platform::CUDADeviceContext>::TYPE&>(dev_cxt),
dims, value, phi::DataType::UNDEFINED, dx);
}
};
