fused_fc_elementwise_layernorm_op support fp16 #44710

Merged: 3 commits, Jul 29, 2022
4 changes: 3 additions & 1 deletion paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -161,7 +161,9 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{
"conv_elementwise_add_fuse_pass",
"gpu_cpu_map_matmul_v2_to_mul_pass", //
"gpu_cpu_map_matmul_v2_to_matmul_pass", //
"fc_fuse_pass"};
"fc_fuse_pass",
"fc_elementwise_layernorm_fuse_pass",
};
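// Note (not part of this diff): kGpuLowerPrecisionPasses appears to be the pass
// list applied when GPU inference runs in low precision (fp16), so adding
// fc_elementwise_layernorm_fuse_pass here lets that path produce the fused op
// whose fp16 CUDA kernel is registered in the second file of this PR.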

const std::vector<std::string> kTrtLowerPrecisionPasses{
// "conv_bn_fuse_pass",
303 changes: 262 additions & 41 deletions paddle/fluid/operators/fused/fused_fc_elementwise_layernorm_op.cu
@@ -20,6 +20,10 @@ limitations under the License. */
namespace cub = hipcub;
#endif

#if defined(PADDLE_WITH_CUDA)
#include <cuda_fp16.h>
#endif

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
@@ -28,9 +32,11 @@ namespace cub = hipcub;
namespace paddle {
namespace operators {

using float16 = phi::dtype::float16;

template <typename T>
static __device__ __forceinline__ T Relu(T x) {
return (x > 0) ? x : 0;
return static_cast<T>(fmaxf(0.f, x));
}

static __device__ __forceinline__ float RealSqrt(float x) { return sqrtf(x); }
@@ -137,6 +143,243 @@ __global__ void InplaceAddReluAddLayerNormKernel(const T* y,
}
}

template <bool DoRelu, int BlockDim>
__global__ void InplaceAddReluAddLayerNormKernel(const float16* y_data,
const float16* bias_0_data,
const float16* bias_1_data,
const float16* scale_data,
float16* out_data,
float16* mean_data,
float16* variance_data,
int M,
int N,
float epsilon) {
#if defined(PADDLE_WITH_CUDA)
const half* y = reinterpret_cast<const half*>(y_data);
const half* bias_0 = reinterpret_cast<const half*>(bias_0_data);
const half* bias_1 = reinterpret_cast<const half*>(bias_1_data);
const half* scale = reinterpret_cast<const half*>(scale_data);
half* out = reinterpret_cast<half*>(out_data);
half* mean = reinterpret_cast<half*>(mean_data);
half* variance = reinterpret_cast<half*>(variance_data);
#else
const float16* y = y_data;
const float16* bias_0 = bias_0_data;
const float16* bias_1 = bias_1_data;
const float16* scale = scale_data;
float16* out = out_data;
float16* mean = mean_data;
float16* variance = variance_data;
#endif
using BlockReduce = cub::BlockReduce<PairForLayerNorm<float>, BlockDim>;
__shared__ typename BlockReduce::TempStorage temp_storage;
#if defined(PADDLE_WITH_CUDA)
__shared__ half shared_mem[BlockDim + 2];
#else
__shared__ float16 shared_mem[BlockDim + 2];
#endif

for (int i = blockIdx.x; i < M; i += gridDim.x) {
int index = i * N + threadIdx.x;

// The first BlockDim elements will be saved to shared memory.
int save_index = threadIdx.x;
#if defined(PADDLE_WITH_CUDA)
half* save_ptr = shared_mem;
#else
float16* save_ptr = shared_mem;
#endif
float sum_i = 0;
float square_sum_i = 0;
for (int j = threadIdx.x; j < N; j += blockDim.x) {
#if defined(PADDLE_WITH_CUDA)
half tmp_0 = out[index];
// Add bias
half tmp_1;
if (bias_0 != nullptr) {
tmp_1 = __hadd(tmp_0, bias_0[j]);
} else {
tmp_1 = tmp_0;
}
// Relu
half tmp_2 = DoRelu ? Relu(tmp_1) : tmp_1;
// elementwise_add
half tmp_3 = __hadd(tmp_2, y[index]);
#else
float16 tmp_0 = out[index];
// Add bias
float16 tmp_1 = bias_0 ? tmp_0 + bias_0[j] : tmp_0;
// Relu
float16 tmp_2 = DoRelu ? Relu(tmp_1) : tmp_1;
// elementwise_add
float16 tmp_3 = tmp_2 + y[index];
#endif
// Save
save_ptr[save_index] = tmp_3;
save_ptr = out;

index += blockDim.x;
save_index = index;

// For layer_norm, reduce to calculate mean and std
sum_i += static_cast<float>(tmp_3);
#if defined(PADDLE_WITH_CUDA) && __CUDA_ARCH__ >= 530
square_sum_i += static_cast<float>(__hmul(tmp_3, tmp_3));
#elif defined(PADDLE_WITH_CUDA)
square_sum_i += static_cast<float>(tmp_3) * static_cast<float>(tmp_3);
#else
square_sum_i += static_cast<float>(tmp_3 * tmp_3);
#endif
}
auto pair = BlockReduce(temp_storage)
.Reduce(PairForLayerNorm<float>(sum_i, square_sum_i),
PairForLayerNormAddFunctor<float>());
if (threadIdx.x == 0) {
#if defined(PADDLE_WITH_CUDA)
half mean_i = static_cast<half>(pair.first_ / N);
#if __CUDA_ARCH__ >= 530
half variance_i = static_cast<half>(
pair.second_ / N - static_cast<float>(__hmul(mean_i, mean_i)));
#else
half variance_i =
static_cast<half>(pair.second_ / N - static_cast<float>(mean_i) *
static_cast<float>(mean_i));
#endif
#else
float16 mean_i = static_cast<float16>(pair.first_ / N);
float16 variance_i = static_cast<float16>(
pair.second_ / N - static_cast<float>(mean_i * mean_i));
#endif
shared_mem[BlockDim] = mean_i;
shared_mem[BlockDim + 1] = variance_i;
if (mean) {
mean[blockIdx.x] = mean_i;
}
if (variance) {
variance[blockIdx.x] = variance_i;
}
}
__syncthreads();
#if defined(PADDLE_WITH_CUDA)
half mean_i = shared_mem[BlockDim];
half std_i = static_cast<half>(
RealSqrt(static_cast<float>(shared_mem[BlockDim + 1]) + epsilon));
#else
float16 mean_i = shared_mem[BlockDim];
float16 std_i = static_cast<float16>(
RealSqrt(static_cast<float>(shared_mem[BlockDim + 1]) + epsilon));
#endif

index = i * N + threadIdx.x;
// The first BlockDim elements are loaded from shared memory.
save_index = threadIdx.x;
save_ptr = shared_mem;

// For layer_norm, calculate out
for (int j = threadIdx.x; j < N; j += blockDim.x) {
#if defined(PADDLE_WITH_CUDA)
#if __CUDA_ARCH__ >= 530
half tmp_0 = __hdiv(__hsub(save_ptr[save_index], mean_i), std_i);
half tmp_1 = scale ? __hmul(scale[j], tmp_0) : tmp_0;
#else
half tmp_0 = static_cast<half>(
(static_cast<float>(save_ptr[save_index]) - static_cast<float>(mean_i)) /
static_cast<float>(std_i));
half tmp_1 = scale ? static_cast<half>(static_cast<float>(scale[j]) *
static_cast<float>(tmp_0))
: tmp_0;
#endif
if (bias_1 != nullptr) {
out[index] = __hadd(tmp_1, bias_1[j]);
} else {
out[index] = tmp_1;
}
#else
float16 tmp_0 = (save_ptr[save_index] - mean_i) / std_i;
float16 tmp_1 = scale ? scale[j] * tmp_0 : tmp_0;
out[index] = bias_1 ? tmp_1 + bias_1[j] : tmp_1;
#endif
save_ptr = out;
index += blockDim.x;
save_index = index;
}
}
}
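
// Summary comment (not part of this PR): both kernel variants reduce the pair
// (sum, square_sum) over each row of length N and then compute
//   mean_i = sum / N
//   var_i  = square_sum / N - mean_i * mean_i
//   out[i][j] = scale[j] * (x[i][j] - mean_i) / sqrt(var_i + epsilon) + bias_1[j]
// where x = (optional relu of) fc_out + bias_0, plus y, matching standard
// layer_norm statistics.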

template <typename T>
void AddReluAddLayerNorm(gpuStream_t stream,
bool with_relu,
int max_threads,
const T* y,
const T* bias_0,
const T* bias_1,
const T* scale,
T* out,
T* mean,
T* variance,
int M,
int N,
float epsilon) {
if (with_relu) {
switch (platform::RoundToPowerOfTwo(N)) {
CUDA_LAUNCH_KERNEL_HELPER(
InplaceAddReluAddLayerNormKernel<T, true, kPowerOfTwoDim>
<<<std::max(max_threads / kPowerOfTwoDim, 1),
kPowerOfTwoDim,
0,
stream>>>(
y, bias_0, bias_1, scale, out, mean, variance, M, N, epsilon));
}
} else {
switch (platform::RoundToPowerOfTwo(N)) {
CUDA_LAUNCH_KERNEL_HELPER(
InplaceAddReluAddLayerNormKernel<T, false, kPowerOfTwoDim>
<<<std::max(max_threads / kPowerOfTwoDim, 1),
kPowerOfTwoDim,
0,
stream>>>(
y, bias_0, bias_1, scale, out, mean, variance, M, N, epsilon));
}
}
}

template <>
void AddReluAddLayerNorm(gpuStream_t stream,
bool with_relu,
int max_threads,
const float16* y,
const float16* bias_0,
const float16* bias_1,
const float16* scale,
float16* out,
float16* mean,
float16* variance,
int M,
int N,
float epsilon) {
if (with_relu) {
switch (platform::RoundToPowerOfTwo(N)) {
CUDA_LAUNCH_KERNEL_HELPER(
InplaceAddReluAddLayerNormKernel<true, kPowerOfTwoDim>
<<<std::max(max_threads / kPowerOfTwoDim, 1),
kPowerOfTwoDim,
0,
stream>>>(
y, bias_0, bias_1, scale, out, mean, variance, M, N, epsilon));
}
} else {
switch (platform::RoundToPowerOfTwo(N)) {
CUDA_LAUNCH_KERNEL_HELPER(
InplaceAddReluAddLayerNormKernel<false, kPowerOfTwoDim>
<<<std::max(max_threads / kPowerOfTwoDim, 1),
kPowerOfTwoDim,
0,
stream>>>(
y, bias_0, bias_1, scale, out, mean, variance, M, N, epsilon));
}
}
}
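
// Standalone sketch (hypothetical kernel, not part of this file) of the arch
// guard used by the fp16 kernel above: half arithmetic intrinsics such as
// __hadd and __hmul are only available on devices with compute capability
// >= 5.3, so older architectures fall back to float math. Assumes
// <cuda_fp16.h> is included, as it is at the top of this file.
__global__ void HalfAddExampleKernel(const half* x, const half* y, half* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  // Native half addition on sm_53 and newer.
  out[i] = __hadd(x[i], y[i]);
#else
  // Convert to float, add, and convert back on older devices.
  out[i] = __float2half(__half2float(x[i]) + __half2float(y[i]));
#endif
}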

template <typename T>
class FusedFCElementwiseLayerNormOpKernel : public framework::OpKernel<T> {
public:
@@ -169,7 +412,6 @@ class FusedFCElementwiseLayerNormOpKernel : public framework::OpKernel<T> {
static_cast<T>(0.0),
out_data,
N);

auto* y = ctx.Input<framework::Tensor>("Y");
auto* bias_0 = ctx.Input<framework::Tensor>("Bias0");
auto* bias_1 = ctx.Input<framework::Tensor>("Bias1");
@@ -192,49 +434,28 @@ class FusedFCElementwiseLayerNormOpKernel : public framework::OpKernel<T> {
float epsilon = ctx.Attr<float>("epsilon");

int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
if (with_relu) {
switch (platform::RoundToPowerOfTwo(N)) {
CUDA_LAUNCH_KERNEL_HELPER(
InplaceAddReluAddLayerNormKernel<T, true, kPowerOfTwoDim>
<<<std::max(max_threads / kPowerOfTwoDim, 1),
kPowerOfTwoDim,
0,
dev_ctx.stream()>>>(y_data,
bias_0_data,
bias_1_data,
scale_data,
out_data,
mean_data,
variance_data,
M,
N,
epsilon));
}
} else {
switch (platform::RoundToPowerOfTwo(N)) {
CUDA_LAUNCH_KERNEL_HELPER(
InplaceAddReluAddLayerNormKernel<T, false, kPowerOfTwoDim>
<<<std::max(max_threads / kPowerOfTwoDim, 1),
kPowerOfTwoDim,
0,
dev_ctx.stream()>>>(y_data,
bias_0_data,
bias_1_data,
scale_data,
out_data,
mean_data,
variance_data,
M,
N,
epsilon));
}
}
AddReluAddLayerNorm(dev_ctx.stream(),
with_relu,
max_threads,
y_data,
bias_0_data,
bias_1_data,
scale_data,
out_data,
mean_data,
variance_data,
M,
N,
epsilon);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(fused_fc_elementwise_layernorm,
ops::FusedFCElementwiseLayerNormOpKernel<float>);
REGISTER_OP_CUDA_KERNEL(
fused_fc_elementwise_layernorm,
ops::FusedFCElementwiseLayerNormOpKernel<phi::dtype::float16>,
ops::FusedFCElementwiseLayerNormOpKernel<float>,
ops::FusedFCElementwiseLayerNormOpKernel<double>);
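
For reviewers who want a host-side oracle, the sketch below is a hypothetical single-row reference of what the fused op computes: the FC output is assumed already materialized by the GEMM, then Bias0 is added, relu is applied when requested, Y is added elementwise, and layer_norm with Scale/Bias1 normalizes the row. The function name FusedFcEltwiseLayerNormRef and its std::vector-based signature are illustrative only and not part of the PR.

#include <algorithm>
#include <cmath>
#include <vector>

// Reference for one row of length N; an empty vector means "input not provided".
std::vector<float> FusedFcEltwiseLayerNormRef(const std::vector<float>& fc_out,  // one row of X*W
                                              const std::vector<float>& bias_0,
                                              const std::vector<float>& y,
                                              const std::vector<float>& scale,
                                              const std::vector<float>& bias_1,
                                              bool with_relu,
                                              float epsilon) {
  const int n = static_cast<int>(fc_out.size());
  std::vector<float> tmp(n), out(n);
  float sum = 0.f, square_sum = 0.f;
  for (int j = 0; j < n; ++j) {
    float v = fc_out[j] + (bias_0.empty() ? 0.f : bias_0[j]);  // add Bias0
    if (with_relu) v = std::max(0.f, v);                       // optional relu
    v += y[j];                                                 // elementwise_add with Y
    tmp[j] = v;
    sum += v;
    square_sum += v * v;
  }
  const float mean = sum / n;
  const float variance = square_sum / n - mean * mean;  // E[x^2] - (E[x])^2
  const float inv_std = 1.f / std::sqrt(variance + epsilon);
  for (int j = 0; j < n; ++j) {
    float v = (tmp[j] - mean) * inv_std;   // normalize
    if (!scale.empty()) v *= scale[j];     // Scale
    if (!bias_1.empty()) v += bias_1[j];   // Bias1
    out[j] = v;
  }
  return out;
}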