-
Notifications
You must be signed in to change notification settings - Fork 5.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add new API/OP:paddle.Tensor.exponential_ (#38256)
* add new API/OP:paddle.Tensor.exponential_ * fix CI
- Loading branch information
1 parent
c396ee6
commit 3318500
Showing
7 changed files
with
684 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
|
||
#ifdef __NVCC__ | ||
#include <curand_kernel.h> | ||
#endif | ||
#ifdef __HIPCC__ | ||
#include <hiprand_kernel.h> | ||
#endif | ||
|
||
#include "paddle/fluid/framework/tensor.h" | ||
#include "paddle/fluid/platform/device/gpu/gpu_info.h" | ||
#include "paddle/fluid/platform/device_context.h" | ||
#include "paddle/fluid/platform/for_range.h" | ||
#include "paddle/fluid/platform/hostdevice.h" | ||
|
||
namespace paddle { | ||
namespace distribution { | ||
|
||
using Tensor = framework::Tensor; | ||
|
||
template <typename T> | ||
struct exponential_transform { | ||
explicit exponential_transform(T lambda) : lambda_(lambda) {} | ||
|
||
HOSTDEVICE inline T operator()(T val) const { | ||
#if defined(__NVCC__) || defined(__HIPCC__) | ||
if (std::is_same<T, double>::value) { | ||
return static_cast<T>(-1.0) / lambda_ * log(val); | ||
} else { | ||
return static_cast<T>(-1.0) / lambda_ * __logf(val); | ||
} | ||
#else | ||
return static_cast<T>(-1.0) / lambda_ * std::log(static_cast<T>(1.0) - val); | ||
#endif | ||
} | ||
|
||
private: | ||
T lambda_; | ||
}; | ||
|
||
#if defined(__NVCC__) || defined(__HIPCC__) | ||
template <typename T> | ||
struct uniform_distribution; | ||
|
||
template <typename T> | ||
struct normal_distribution; | ||
|
||
#if defined(__NVCC__) | ||
template <> | ||
struct uniform_distribution<float> { | ||
__device__ inline float4 operator()(curandStatePhilox4_32_10_t *state) const { | ||
return curand_uniform4(state); | ||
} | ||
static constexpr int kReturnsCount = 4; | ||
}; | ||
|
||
template <> | ||
struct uniform_distribution<double> { | ||
__device__ inline double2 operator()( | ||
curandStatePhilox4_32_10_t *state) const { | ||
return curand_uniform2_double(state); | ||
} | ||
static constexpr int kReturnsCount = 2; | ||
}; | ||
|
||
template <> | ||
struct normal_distribution<float> { | ||
__device__ inline float4 operator()(curandStatePhilox4_32_10_t *state) const { | ||
return curand_normal4(state); | ||
} | ||
static constexpr int kReturnsCount = 4; | ||
}; | ||
|
||
template <> | ||
struct normal_distribution<double> { | ||
__device__ inline double2 operator()( | ||
curandStatePhilox4_32_10_t *state) const { | ||
return curand_normal2_double(state); | ||
} | ||
static constexpr int kReturnsCount = 2; | ||
}; | ||
|
||
#else | ||
template <> | ||
struct uniform_distribution<float> { | ||
__device__ inline float4 operator()( | ||
hiprandStatePhilox4_32_10_t *state) const { | ||
return hiprand_uniform4(state); | ||
} | ||
static constexpr int kReturnsCount = 4; | ||
}; | ||
|
||
template <> | ||
struct uniform_distribution<double> { | ||
__device__ inline double2 operator()( | ||
hiprandStatePhilox4_32_10_t *state) const { | ||
return hiprand_uniform2_double(state); | ||
} | ||
static constexpr int kReturnsCount = 2; | ||
}; | ||
|
||
template <> | ||
struct normal_distribution<float> { | ||
__device__ inline float4 operator()( | ||
hiprandStatePhilox4_32_10_t *state) const { | ||
return hiprand_normal4(state); | ||
} | ||
static constexpr int kReturnsCount = 4; | ||
}; | ||
|
||
template <> | ||
struct normal_distribution<double> { | ||
__device__ inline double2 operator()( | ||
hiprandStatePhilox4_32_10_t *state) const { | ||
return hiprand_normal2_double(state); | ||
} | ||
static constexpr int kReturnsCount = 2; | ||
}; | ||
#endif | ||
|
||
template <typename T, typename DistOp, typename TransformOp> | ||
__global__ void DistributionKernel(size_t size, uint64_t seed, uint64_t offset, | ||
DistOp dist, TransformOp trans, | ||
T *out_data) { | ||
size_t idx = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x); | ||
int32_t returns_count = DistOp::kReturnsCount; | ||
#if defined(__NVCC__) | ||
curandStatePhilox4_32_10_t state; | ||
curand_init(seed, idx, offset, &state); | ||
#else | ||
hiprandStatePhilox4_32_10_t state; | ||
hiprand_init(seed, idx, offset, &state); | ||
#endif | ||
size_t total_thread = gridDim.x * blockDim.x; | ||
for (size_t i = idx; i < size; i += total_thread * returns_count) { | ||
auto random_tuple = dist(&state); | ||
for (size_t j = 0; j < returns_count; j++) { | ||
size_t index = i + j * total_thread; | ||
if (index < size) { | ||
auto random = static_cast<T>((&random_tuple.x)[j]); | ||
out_data[index] = trans(random); | ||
} | ||
} | ||
} | ||
} | ||
|
||
template <typename T, typename DistOp, typename TransformOp> | ||
void distribution_and_transform(const platform::CUDADeviceContext &dev_ctx, | ||
Tensor *out, DistOp dist, TransformOp trans) { | ||
T *out_data = out->mutable_data<T>(dev_ctx.GetPlace()); | ||
auto size = out->numel(); | ||
|
||
int64_t device_id = | ||
BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()).GetDeviceId(); | ||
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id); | ||
|
||
size_t block_size = 256; | ||
size_t expect_grid_size = (size + block_size - 1) / block_size; | ||
const auto &prop = platform::GetDeviceProperties(device_id); | ||
size_t max_grid_size = (prop.maxThreadsPerMultiProcessor / block_size) * | ||
prop.multiProcessorCount; | ||
size_t grid_size = | ||
expect_grid_size > max_grid_size ? max_grid_size : expect_grid_size; | ||
|
||
size_t total_thread = block_size * grid_size; | ||
size_t curand4_loop_times = | ||
(size + 4 * total_thread - 1) / (4 * total_thread); | ||
// 'increment' shoulde be multiple of 4 | ||
uint64_t increment = curand4_loop_times * 4; | ||
|
||
auto seed_offset = gen_cuda->IncrementOffset(increment); | ||
uint64_t seed = seed_offset.first; | ||
uint64_t offset = seed_offset.second; | ||
|
||
DistributionKernel< | ||
T, DistOp, TransformOp><<<grid_size, block_size, 0, dev_ctx.stream()>>>( | ||
size, seed, offset, dist, trans, out_data); | ||
} | ||
|
||
#endif | ||
|
||
} // namespace distribution | ||
} // namespace paddle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/fluid/operators/exponential_op.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
class ExponentialOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
void InferShape(framework::InferShapeContext *ctx) const override { | ||
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ExponentialOp"); | ||
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ExponentialOp"); | ||
auto dim = ctx->GetInputDim("X"); | ||
ctx->SetOutputDim("Out", dim); | ||
} | ||
|
||
protected: | ||
framework::OpKernelType GetExpectedKernelType( | ||
const framework::ExecutionContext &ctx) const override { | ||
return framework::OpKernelType( | ||
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); | ||
} | ||
}; | ||
|
||
class ExponentialOpMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
void Make() override { | ||
AddComment(R"DOC( | ||
This operator fills the input tensor with random values sampled from a | ||
exponential distribution. | ||
)DOC"); | ||
AddInput("X", "The input tensor."); | ||
AddOutput("Out", "The output tensor of exponential OP."); | ||
AddAttr<float>( | ||
"lambda", "lambd parameter of exponential distribution. [default 1.0].") | ||
.SetDefault(1.0f); | ||
} | ||
}; | ||
|
||
class ExponentialOpInferVarType | ||
: public framework::PassInDtypeAndVarTypeToOutput { | ||
protected: | ||
std::unordered_map<std::string, std::string> &GetInputOutputWithSameType() | ||
const override { | ||
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}}; | ||
return m; | ||
} | ||
}; | ||
|
||
template <typename T> | ||
class ExponentialKernel<platform::CPUDeviceContext, T> | ||
: public framework::OpKernel<T> { | ||
public: | ||
void Compute(const framework::ExecutionContext &ctx) const override { | ||
auto *out = ctx.Output<framework::Tensor>("Out"); | ||
T *out_data = out->mutable_data<T>(ctx.GetPlace()); | ||
|
||
T lambda = static_cast<T>(ctx.Attr<float>("lambda")); | ||
int64_t size = out->numel(); | ||
|
||
auto gen = framework::DefaultCPUGenerator(); | ||
auto engine = gen->GetCPUEngine(); | ||
|
||
std::uniform_real_distribution<T> uniform(0.0, 1.0); | ||
distribution::exponential_transform<T> trans(lambda); | ||
for (int64_t i = 0; i < size; ++i) { | ||
out_data[i] = trans(uniform(*engine)); | ||
} | ||
} | ||
}; | ||
|
||
class ExponentialGradOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
void InferShape(framework::InferShapeContext *ctx) const override { | ||
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", | ||
"Out_Grad", "ExponentialGradOp"); | ||
|
||
auto dout_dim = ctx->GetInputDim(framework::GradVarName("Out")); | ||
ctx->SetOutputDim(framework::GradVarName("X"), dout_dim); | ||
} | ||
}; | ||
|
||
template <typename T> | ||
class ExponentialGradOpMaker : public framework::SingleGradOpMaker<T> { | ||
public: | ||
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; | ||
|
||
protected: | ||
void Apply(GradOpPtr<T> retv) const override { | ||
retv->SetType("exponential_grad"); | ||
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); | ||
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); | ||
retv->SetAttrMap(this->Attrs()); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
namespace plat = paddle::platform; | ||
|
||
DECLARE_INPLACE_OP_INFERER(ExponentialInferer, {"X", "Out"}); | ||
DECLARE_INPLACE_OP_INFERER(ExponentialGradInferer, | ||
{paddle::framework::GradVarName("Out"), | ||
paddle::framework::GradVarName("X")}); | ||
|
||
REGISTER_OPERATOR(exponential, ops::ExponentialOp, ops::ExponentialOpMaker, | ||
ops::ExponentialOpInferVarType, | ||
ops::ExponentialGradOpMaker<paddle::framework::OpDesc>, | ||
ops::ExponentialGradOpMaker<paddle::imperative::OpBase>, | ||
ExponentialInferer); | ||
REGISTER_OPERATOR(exponential_grad, ops::ExponentialGradOp, | ||
ExponentialGradInferer); | ||
|
||
REGISTER_OP_CPU_KERNEL(exponential, | ||
ops::ExponentialKernel<plat::CPUDeviceContext, float>, | ||
ops::ExponentialKernel<plat::CPUDeviceContext, double>); | ||
REGISTER_OP_CPU_KERNEL( | ||
exponential_grad, ops::ExponentialGradKernel<plat::CPUDeviceContext, float>, | ||
ops::ExponentialGradKernel<plat::CPUDeviceContext, double>); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/fluid/operators/exponential_op.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
template <typename T> | ||
class ExponentialKernel<platform::CUDADeviceContext, T> | ||
: public framework::OpKernel<T> { | ||
public: | ||
void Compute(const framework::ExecutionContext& ctx) const override { | ||
framework::Tensor* out = ctx.Output<framework::Tensor>("Out"); | ||
auto& dev_cxt = ctx.template device_context<platform::CUDADeviceContext>(); | ||
T lambda = static_cast<T>(ctx.Attr<float>("lambda")); | ||
|
||
distribution::uniform_distribution<T> dist; | ||
distribution::exponential_transform<T> trans(lambda); | ||
distribution::distribution_and_transform<T>(dev_cxt, out, dist, trans); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
namespace plat = paddle::platform; | ||
|
||
REGISTER_OP_CUDA_KERNEL( | ||
exponential, ops::ExponentialKernel<plat::CUDADeviceContext, float>, | ||
ops::ExponentialKernel<plat::CUDADeviceContext, double>); | ||
REGISTER_OP_CUDA_KERNEL( | ||
exponential_grad, | ||
ops::ExponentialGradKernel<plat::CUDADeviceContext, float>, | ||
ops::ExponentialGradKernel<plat::CUDADeviceContext, double>); |
Oops, something went wrong.