diff --git a/.github/workflows/metax_work.yaml b/.github/workflows/metax_work.yaml index aff530d475c..f14023848c6 100644 --- a/.github/workflows/metax_work.yaml +++ b/.github/workflows/metax_work.yaml @@ -1,4 +1,4 @@ -name: padlle metax gpu test +name: paddle metax gpu test on: workflow_dispatch: diff --git a/backends/metax_gpu/CMakeLists.txt b/backends/metax_gpu/CMakeLists.txt index 5930eaaebd2..2bb282cf54f 100755 --- a/backends/metax_gpu/CMakeLists.txt +++ b/backends/metax_gpu/CMakeLists.txt @@ -326,6 +326,8 @@ file( ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/im2sequence_kernel.cu ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/im2sequence_grad_kernel.cu ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/increment_kernel.cu + ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/embedding_grad_add_to_kernel.cu + # ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/cross_entropy_bwd_w_downcast.cu # ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/index_elementwise_get_kernel.cu ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/index_elementwise_get_grad_kernel.cu ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/index_elementwise_put_kernel.cu @@ -728,6 +730,11 @@ target_link_libraries( ${WARPCTC_LIBRARIES} ${WARPRNNT_LIBRARIES} ${PADDLE_CORE_LIB}) + +target_link_libraries(${TARGET_NAME} /opt/maca/lib/libmccl.so) +target_link_libraries(${TARGET_NAME} /opt/maca/lib/libmcFlashAttn.so) +target_link_libraries(${TARGET_NAME} /opt/maca/lib/libmcpti.so) + include_directories(BEFORE ${PADDLE_SOURCE_DIR}) target_compile_definitions( diff --git a/backends/metax_gpu/kernels/cuda_kernels/cross_entropy_bwd_w_downcast.cu b/backends/metax_gpu/kernels/cuda_kernels/cross_entropy_bwd_w_downcast.cu new file mode 100644 index 00000000000..a0d5dfd7a5a --- /dev/null +++ b/backends/metax_gpu/kernels/cuda_kernels/cross_entropy_bwd_w_downcast.cu @@ -0,0 +1,291 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/cross_entropy_grad_kernel.h" + +#ifdef __NVCC__ +#include "cub/cub.cuh" +#endif +#ifdef __HIPCC__ +#include +namespace cub = hipcub; +#endif + +#include "kernels/gpudnn/softmax_gpudnn.h" +#include "paddle/phi/backends/gpu/gpu_device_function.h" +#include "paddle/phi/backends/gpu/gpu_dnn.h" +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/core/visit_type.h" +#include "paddle/phi/kernels/funcs/axis_utils.h" +#include "paddle/phi/kernels/funcs/for_range.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/softmax.h" + +namespace phi { + +/* + Vectorized wrapper of softmax with cross entropy grad hard label. + Optimized with float4 vectorization for memory coalescing and improved + throughput. 
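+ Falls back to a warp-level kernel (see LaunchOptimizedCrossEntropyGradKernel below) when the
+ buffers are not 16-byte aligned or the element count is not a multiple of the vector width.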
+*/ +template +__global__ void SoftmaxWithCrossEntropyGradHardLabelVectorized( + LogitT* __restrict__ logits_grad, + const T* __restrict__ loss_grad, + const T* __restrict__ softmax, + const LabelT* __restrict__ labels, + const int64_t n, + const int64_t dim, + const int64_t d, + const int ignore_index) { + // Vectorized load/store with float4 for 128-bit memory transactions + constexpr int VEC_SIZE = 4; + using VecT = typename phi::AlignedVector; + using SoftmaxVecT = typename phi::AlignedVector; + + int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + int64_t vec_id = tid * VEC_SIZE; + + // Ensure we don't exceed bounds + if (vec_id >= n * dim * d) return; + + // Compute indices for vectorized access + int64_t idx_n = vec_id / (d * dim); + int64_t idx_dim_start = (vec_id / d) % dim; + int64_t idx_d = vec_id % d; + int64_t ids = idx_n * d + idx_d; + + // Load label once per thread + auto lbl = static_cast(labels[ids]); + + if (lbl == ignore_index) { + // Vectorized zero fill for ignore_index + VecT* vec_grad = reinterpret_cast(&logits_grad[vec_id]); + VecT zero_vec; +#pragma unroll + for (int i = 0; i < VEC_SIZE; ++i) { + zero_vec.val[i] = static_cast(0.0f); + } + *vec_grad = zero_vec; + return; + } + + // Vectorized load of softmax values + SoftmaxVecT softmax_vec; + const SoftmaxVecT* softmax_ptr = + reinterpret_cast(&softmax[vec_id]); + softmax_vec = *softmax_ptr; + + // Load loss gradient (broadcast across vector elements) + T loss_grad_val = loss_grad[ids]; + + // Vectorized computation + VecT grad_vec; +#pragma unroll + for (int i = 0; i < VEC_SIZE; ++i) { + int64_t current_dim = idx_dim_start + i; + if (current_dim < dim) { // Bounds check for partial vectors + float softmax_val = static_cast(softmax_vec.val[i]); + float grad_val; + + if (lbl == current_dim) { + grad_val = (softmax_val - 1.0f) * static_cast(loss_grad_val); + } else { + grad_val = softmax_val * static_cast(loss_grad_val); + } + + grad_vec.val[i] = static_cast(grad_val); + } else { + grad_vec.val[i] = static_cast(0.0f); + } + } + + // Vectorized store + VecT* grad_ptr = reinterpret_cast(&logits_grad[vec_id]); + *grad_ptr = grad_vec; +} + +/* + Specialized kernel for dimensions not divisible by vector size + Uses warp-level primitives for better performance on irregular sizes +*/ +template +__global__ void SoftmaxWithCrossEntropyGradHardLabelWarp( + LogitT* __restrict__ logits_grad, + const T* __restrict__ loss_grad, + const T* __restrict__ softmax, + const LabelT* __restrict__ labels, + const int64_t n, + const int64_t dim, + const int64_t d, + const int ignore_index) { + const int warps_per_block = 4; + const int threads_per_warp = 32; + const int threads_per_block = warps_per_block * threads_per_warp; + + int tid = blockIdx.x * threads_per_block + threadIdx.x; + int warp_id = threadIdx.x / threads_per_warp; + int lane_id = threadIdx.x % threads_per_warp; + + // Process multiple elements per thread using warp-level parallelism + int64_t elements_per_thread = + (n * dim * d + gridDim.x * threads_per_block - 1) / + (gridDim.x * threads_per_block); + + for (int e = 0; e < elements_per_thread; ++e) { + int64_t idx = tid + e * gridDim.x * threads_per_block; + if (idx >= n * dim * d) break; + + int64_t idx_n = idx / (d * dim); + int64_t idx_dim = (idx / d) % dim; + int64_t idx_d = idx % d; + int64_t ids = idx_n * d + idx_d; + + auto lbl = static_cast(labels[ids]); + + if (lbl == ignore_index) { + logits_grad[idx] = static_cast(0.0f); + } else if (lbl == idx_dim) { + logits_grad[idx] = + 
static_cast((static_cast(softmax[idx]) - 1.0f) * + static_cast(loss_grad[ids])); + } else { + logits_grad[idx] = + static_cast(static_cast(softmax[idx]) * + static_cast(loss_grad[ids])); + } + } +} + +/* + Optimized kernel selector based on problem size and alignment +*/ +template +void LaunchOptimizedCrossEntropyGradKernel(const GPUContext& dev_ctx, + LogitT* logits_grad, + const T* loss_grad, + const T* softmax, + const LabelT* labels, + const int64_t n, + const int64_t dim, + const int64_t d, + const int ignore_index) { + const int64_t total_elements = n * dim * d; + auto stream = dev_ctx.stream(); + + // Check alignment for vectorized kernel + bool is_aligned = (reinterpret_cast(logits_grad) % 16 == 0) && + (reinterpret_cast(softmax) % 16 == 0) && + (total_elements % 4 == 0); + + if (is_aligned && total_elements >= 1024) { + // Use vectorized kernel for aligned, large problems + constexpr int VEC_SIZE = 4; + const int threads_per_block = 256; + const int vec_elements = total_elements / VEC_SIZE; + const int blocks = + (vec_elements + threads_per_block - 1) / threads_per_block; + + SoftmaxWithCrossEntropyGradHardLabelVectorized + <<>>( + logits_grad, loss_grad, softmax, labels, n, dim, d, ignore_index); + } else { + // Use warp-specialized kernel for irregular sizes + const int warps_per_block = 4; + const int threads_per_block = warps_per_block * 32; + const int blocks = + std::min(1024, + static_cast((total_elements + threads_per_block - 1) / + threads_per_block)); + + SoftmaxWithCrossEntropyGradHardLabelWarp + <<>>( + logits_grad, loss_grad, softmax, labels, n, dim, d, ignore_index); + } +} + +template +void CrossEntropyWithSoftmaxBwdWithDowncastGPUKernel( + const GPUContext& dev_ctx, + const DenseTensor& label, + const DenseTensor& softmax, + const DenseTensor& loss_grad, + int axis, + DenseTensor* logits_grad) { + // PADDLE_ENFORCE_EQ( + // dev_ctx.GetPlace().GetType(), + // phi::AllocationType::GPU, + // common::errors::Unavailable("softmax_with_cross_entropy operator's " + // "CUDA kernel only runs on GPU device.")); + + using LogitT = phi::bfloat16; + const T* loss_grad_data = loss_grad.data(); + DenseTensor* logit_grad = logits_grad; + + LogitT* logit_grad_data = nullptr; + logit_grad_data = dev_ctx.template Alloc(logit_grad); + + const int rank = logit_grad->dims().size(); + const int axis_v = phi::funcs::CanonicalAxis(axis, rank); + int axis_dim = logit_grad->dims()[axis_v]; + + const int64_t n = phi::funcs::SizeToAxis(axis_v, logit_grad->dims()); + const int64_t d = phi::funcs::SizeFromAxis(axis_v, logit_grad->dims()); + const int64_t remain = d / axis_dim; + + const T* softmax_data = softmax.data(); + const auto* label_data = label.data(); + + // Launch optimized kernel with automatic selection + LaunchOptimizedCrossEntropyGradKernel(dev_ctx, + logit_grad_data, + loss_grad_data, + softmax_data, + label_data, + n, + axis_dim, + remain, + -100); +} + +template +void CrossEntropyWithSoftmaxBwdWithDowncastKernel(const Context& dev_ctx, + const DenseTensor& label, + const DenseTensor& softmax, + const DenseTensor& loss_grad, + DenseTensor* logits_grad) { + constexpr int axis = -1; + if (logits_grad->numel() == 0) { + dev_ctx.template Alloc(logits_grad); + return; + } + auto dtype = label.dtype(); + PD_VISIT_INTEGRAL_TYPES( + dtype, "CrossEntropyWithSoftmaxBwdWithDowncastGPUKernel", ([&] { + CrossEntropyWithSoftmaxBwdWithDowncastGPUKernel( + dev_ctx, label, softmax, loss_grad, axis, logits_grad); + })); +} + +} // namespace phi + 
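+// NOTE: whatever input dtype is registered below, the logits gradient is always
+// written as bfloat16 ("downcast"), since LogitT is fixed to phi::bfloat16 in
+// CrossEntropyWithSoftmaxBwdWithDowncastGPUKernel above.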
+PD_REGISTER_PLUGIN_KERNEL(cross_entropy_with_softmax_bwd_w_downcast, + metax_gpu, + ALL_LAYOUT, + phi::CrossEntropyWithSoftmaxBwdWithDowncastKernel, + float, + double, + phi::float16) {} diff --git a/backends/metax_gpu/kernels/cuda_kernels/embedding_grad_add_to_kernel.cu b/backends/metax_gpu/kernels/cuda_kernels/embedding_grad_add_to_kernel.cu new file mode 100644 index 00000000000..6b20feee0fd --- /dev/null +++ b/backends/metax_gpu/kernels/cuda_kernels/embedding_grad_add_to_kernel.cu @@ -0,0 +1,27 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/embedding_grad_kernel.h" +#include "paddle/phi/kernels/funcs/embedding_grad.h" +#include "paddle/phi/kernels/gpu/embedding_grad_add_to_kernel.cu" // NOLINT + +PD_CUSTOM_KERNEL_REGISTER(embedding_grad_add_to, + metax_gpu, + ALL_LAYOUT, + phi::EmbeddingGradAddToAddToKernel, + float, + double, + phi::float16, + phi::bfloat16) {} diff --git a/backends/metax_gpu/kernels/cuda_kernels/moe_combine_no_weight_grad_kernel.cu b/backends/metax_gpu/kernels/cuda_kernels/moe_combine_no_weight_grad_kernel.cu new file mode 100644 index 00000000000..e6984cf86d2 --- /dev/null +++ b/backends/metax_gpu/kernels/cuda_kernels/moe_combine_no_weight_grad_kernel.cu @@ -0,0 +1,25 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/legacy/gpu/moe_combine_no_weight_grad_kernel.cu" // NOLINT + +PD_CUSTOM_KERNEL_REGISTER(moe_combine_no_weight_grad, + metax_gpu, + ALL_LAYOUT, + phi::MoeCombineNoWeightGradKernel, + float, + double, + phi::bfloat16, + phi::float16) {} diff --git a/backends/metax_gpu/kernels/cuda_kernels/multihead_matmul_kernel.cu b/backends/metax_gpu/kernels/cuda_kernels/multihead_matmul_kernel.cu new file mode 100644 index 00000000000..151c929e41c --- /dev/null +++ b/backends/metax_gpu/kernels/cuda_kernels/multihead_matmul_kernel.cu @@ -0,0 +1,433 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "kernels/funcs/blas/blas.h" +#include "paddle/common/errors.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/multihead_matmul_functor.h" + +namespace phi { +namespace fusion { + +template +__global__ void transpose(T *src, + T *dst, + const int batch_size, + const int seq_len, + const int head_num, + const int size_per_head) { + int batch_id = blockIdx.x / (head_num * seq_len); + int seq_id = blockIdx.x % seq_len; + int head_id = (blockIdx.x % (head_num * seq_len)) / seq_len; + dst[batch_id * (head_num * seq_len * size_per_head) + + seq_id * head_num * size_per_head + head_id * size_per_head + + threadIdx.x] = src[blockIdx.x * size_per_head + threadIdx.x]; +} + +template +inline __device__ T add_func(T a, T b); + +template <> +__device__ float add_func(float a, float b) { + return a + b; +} + +template <> +__device__ float2 add_func(float2 a, float2 b) { + float2 c; + c.x = a.x + b.x; + c.y = a.y + b.y; + return c; +} + +template <> +__device__ float4 add_func(float4 a, float4 b) { + float4 c; + c.x = a.x + b.x; + c.y = a.y + b.y; + c.z = a.z + b.z; + c.w = a.w + b.w; + return c; +} +#if defined(PADDLE_WITH_CUDA) +template <> +__device__ half2 add_func(half2 a, half2 b) { +#if __CUDA_ARCH__ >= 530 + return __hadd2(a, b); +#else + return half2(__float2half(__half2float(a.x) + __half2float(b.x)), + __float2half(__half2float(b.x) + __half2float(b.y))); +#endif +} + +template <> +__device__ half add_func(half a, half b) { +#if __CUDA_ARCH__ >= 530 + return __hadd(a, b); +#else + return __float2half(__half2float(a) + __half2float(b)); +#endif +} +#endif + +template +__global__ void TransposeQkvKernel(const int H, + const T *input, + const T *bias, + T *output) { + // Input: BxSx3xNxH + // Bias: 3xNxH + // Output: 3xBxNxSxH + int n = threadIdx.y; + int s = blockIdx.x; + int b = blockIdx.y; + int m = blockIdx.z; + + const int N = blockDim.y; + const int S = gridDim.x; + const int B = gridDim.y; + + const int NH = N * H; + const int NHS = NH * S; + const int in_offset = n * H + m * NH + s * 3 * NH + b * NHS * 3; + const int bias_offset = m * NH + n * H; + const int out_offset = s * H + n * S * H + b * NHS + m * NHS * B; + + const int i = threadIdx.x; + output[out_offset + i] = + add_func(input[in_offset + i], bias[bias_offset + i]); +} + +template +void TransQKVWithBias(const int batch, + const int seq_len, + const int head_size, + const int head_num, + const T *input, + const T *bias, + T *output, + gpuStream_t stream); + +template <> +void TransQKVWithBias(const int batch, + const int seq_len, + const int head_size, + const int head_num, + const float *input, + const float *bias, + float *output, + gpuStream_t stream) { + // BxSx3xNxH + 3xNxH -> 3xBxNxSxH + int scratch_size = batch * head_num * seq_len * seq_len; + const dim3 grid(seq_len, batch, 3); + // scratch % 4 == 0 to ensure the alignment + if (head_size % 4 == 0 && scratch_size % 4 == 0) { + const int h = head_size / 4; + const float4 *input4 = reinterpret_cast(input); + 
const float4 *bias4 = reinterpret_cast(bias); + float4 *output4 = reinterpret_cast(output); + const dim3 block(h, head_num, 1); + + // limit h * head_num to max block size(1024). + PADDLE_ENFORCE_LE(h * head_num, + 1024, + common::errors::InvalidArgument( + "head_num (%d) * head_size (%d) should <= %d", + head_num, + head_size, + 1024 * 4)); + TransposeQkvKernel + <<>>(h, input4, bias4, output4); + } else if (head_size % 2 == 0 && scratch_size % 2 == 0) { + const int h = head_size / 2; + const float2 *input2 = reinterpret_cast(input); + const float2 *bias2 = reinterpret_cast(bias); + float2 *output2 = reinterpret_cast(output); + const dim3 block(h, head_num, 1); + // limit h * head_num to max block size(1024). + PADDLE_ENFORCE_LE(h * head_num, + 1024, + common::errors::InvalidArgument( + "head_num (%d) * head_size (%d) should <= %d", + head_num, + head_size, + 1024 * 2)); + TransposeQkvKernel + <<>>(h, input2, bias2, output2); + } else { + const dim3 block(head_size, head_num, 1); + // limit head_size * head_num to max block size(1024). + PADDLE_ENFORCE_LE(head_size * head_num, + 1024, + common::errors::InvalidArgument( + "head_num (%d) * head_size (%d) should <= %d", + head_num, + head_size, + 1024)); + TransposeQkvKernel + <<>>(head_size, input, bias, output); + } +} + +#if defined(PADDLE_WITH_CUDA) +template <> +void TransQKVWithBias(const int batch, + const int seq_len, + const int head_size, + const int head_num, + const phi::float16 *input, + const phi::float16 *bias, + phi::float16 *output, + gpuStream_t stream) { + // BxSx3xNxH + 3xNxH -> 3xBxNxSxH + int scratch_size = batch * head_num * seq_len * seq_len; + const dim3 grid(seq_len, batch, 3); + if (head_size % 2 == 0 && scratch_size % 2 == 0) { + const int h = head_size / 2; + const half2 *input2 = reinterpret_cast(input); + const half2 *bias2 = reinterpret_cast(bias); + half2 *output2 = reinterpret_cast(output); + const dim3 block(h, head_num, 1); + // limit h * head_num to max block size(1024). + PADDLE_ENFORCE_LE(h * head_num, + 1024, + common::errors::InvalidArgument( + "head_num (%d) * head_size (%d) should <= %d", + head_num, + head_size, + 1024 * 2)); + TransposeQkvKernel + <<>>(h, input2, bias2, output2); + } else { + const dim3 block(head_size, head_num, 1); + const half *input_half = reinterpret_cast(input); + const half *bias_half = reinterpret_cast(bias); + half *output_half = reinterpret_cast(output); + + // limit head_size * head_num to max block size(1024). 
+ PADDLE_ENFORCE_LE(head_size * head_num, + 1024, + common::errors::InvalidArgument( + "head_num (%d) * head_size (%d) should <= %d", + head_num, + head_size, + 1024)); + TransposeQkvKernel<<>>( + head_size, input_half, bias_half, output_half); + } +} +#endif + +inline int round_up(int seq_len, int multiple = 32) { + PADDLE_ENFORCE_GT( + multiple, + 0, + common::errors::InvalidArgument( + "multiple should be a positive number, but it's (%d)", multiple)); + return ((seq_len + multiple - 1) / multiple) * multiple; +} + +template +__global__ void broadcast(const T *src, + T *dst, + const int seq_len, + const int head_num) { + int batch_id = blockIdx.x / (head_num * seq_len); + int dst_offset = blockIdx.x * seq_len; + if (threadIdx.x < seq_len) { + dst[threadIdx.x + dst_offset] = src[threadIdx.x + batch_id * seq_len]; + } +} + +template +__global__ void broadcast_batch_head_number(const T *src, + T *dst, + const int batch_size, + const int seq_len, + const int head_num) { + int src_seq_id = blockIdx.x % seq_len; + int dst_offset = blockIdx.x * seq_len; + if (threadIdx.x < seq_len) { + dst[threadIdx.x + dst_offset] = src[threadIdx.x + src_seq_id * seq_len]; + } +} + +template +void MultiheadMatmulKernel(const Context &dev_ctx, + const DenseTensor &input, + const DenseTensor &w, + const DenseTensor &bias, + const paddle::optional &bias_qk, + const bool transpose_q, + const bool transpose_k, + const bool transpose_v, + const float alpha, + const int head_number, + DenseTensor *out) { + auto *input_d = input.data(); + auto *w_d = w.data(); + auto *bias_d = bias.data(); + auto *bias_qk_d = bias_qk ? bias_qk->data() : nullptr; + T scale = static_cast(alpha); + + // compute q*k with eltadd + auto stream = dev_ctx.stream(); + // should be (B * S * hidden) + auto input_dims = input.dims(); + // shouble be (hidden * 3 * all_head_size) + auto w_dims = w.dims(); + int batch = input_dims[0]; + int seq_len = input_dims[1]; + int hidden = input_dims[2]; + phi::DenseTensor temp_bias_tensor; + // if bias_qk is[batch, 1, 1, seq_len], the bias_qk_d need to be broadcasted + if (bias_qk && bias_qk->numel() == (batch * seq_len)) { + VLOG(4) << "Do broadcasted bias_qk from [batch, 1, 1, seq_len]"; + temp_bias_tensor.Resize({batch * head_number * seq_len * seq_len}); + auto *temp_qk_bias = dev_ctx.template Alloc( + &temp_bias_tensor, temp_bias_tensor.numel() * sizeof(T)); + int grid = batch * head_number * seq_len; + int block = round_up(seq_len); + broadcast<<>>( + bias_qk_d, temp_qk_bias, seq_len, head_number); + bias_qk_d = static_cast(temp_qk_bias); + } + // if bias_qk is[1, 1, seq_len, seq_len], the bias_qk_d need to be + // broadcasted + if (bias_qk && bias_qk->numel() == (1 * seq_len * seq_len)) { + VLOG(4) << "do broadcasted bias_qk from [1, 1, seq_len, seq_len]"; + temp_bias_tensor.Resize({batch * head_number * seq_len * seq_len}); + auto *temp_qk_bias = dev_ctx.template Alloc( + &temp_bias_tensor, temp_bias_tensor.numel() * sizeof(T)); + int grid = batch * head_number * seq_len; + int block = round_up(seq_len); + broadcast_batch_head_number<<>>( + bias_qk_d, temp_qk_bias, batch, seq_len, head_number); + bias_qk_d = static_cast(temp_qk_bias); + } + if (!bias_qk) { + int size = batch * head_number * seq_len * seq_len; + temp_bias_tensor.Resize({size}); + auto *temp_qk_bias = dev_ctx.template Alloc( + &temp_bias_tensor, temp_bias_tensor.numel() * sizeof(T)); +#ifdef PADDLE_WITH_HIP + hipMemset(temp_qk_bias, 0, sizeof(float) * size); +#else + cudaMemset(temp_qk_bias, 0, sizeof(float) * size); +#endif + 
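+    // No bias_qk was provided: point bias_qk_d at the zero-filled scratch buffer
+    // so the attention bias becomes a no-op.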
bias_qk_d = static_cast(temp_qk_bias); + } + int all_head_size = w_dims[2]; + int head_size = all_head_size / head_number; + + out->Resize({batch, seq_len, all_head_size}); + auto *output_d = dev_ctx.template Alloc(out, out->numel() * sizeof(T)); + + // (B*S, hidden) + const phi::DenseTensor input_matrix = + phi::ReshapeToMatrix(input, 2 /*x_num_col_dims */); + // (hidden, 3 * all_head_size) + const phi::DenseTensor w_matrix = + phi::ReshapeToMatrix(w, 1 /*y_num_col_dims*/); + + phi::DenseTensor temp_out_tensor; + auto temp_out_dims = + common::make_ddim({batch, seq_len, 3, head_number, head_size}); + temp_out_tensor.Resize( + {batch * seq_len, common::product(temp_out_dims) / (batch * seq_len)}); + auto *temp_out_data = dev_ctx.template Alloc( + &temp_out_tensor, temp_out_tensor.numel() * sizeof(T)); + + // (B * S, hidden) * (hidden, 3 * N * H) -> (B * S * 3 * N * H) + auto blas = phi::funcs::GetBlas(dev_ctx); + blas.MatMul(input_matrix, w_matrix, &temp_out_tensor); + VLOG(2) << "(B * S, hidden) * (hidden, 3 * N * H) -> (B * S * 3 * N * H)"; + // temp_out_tensor.Resize(temp_out_dims); + + phi::DenseTensor multihead_temp_tensor; + // B * head_number * S * S * 1 + B * S * 3 * N * H + int scratch_size = batch * head_number * seq_len * seq_len * 1; + multihead_temp_tensor.Resize({scratch_size + temp_out_tensor.numel()}); + auto *multihead_temp_data = dev_ctx.template Alloc( + &multihead_temp_tensor, multihead_temp_tensor.numel() * sizeof(T)); + + auto *qkptr = multihead_temp_data; + auto *tptr = multihead_temp_data + scratch_size; + + // Do the transpose with bias. + // BxSx3xNxH => tptr: 3xBxNxSxH. + TransQKVWithBias(batch, + seq_len, + head_size, + head_number, + temp_out_data, + bias_d, + tptr, + stream); + if (std::is_same::value) { + phi::funcs::MultiheadGPUComputeFunctor multihead_compute_func; + multihead_compute_func(dev_ctx, + batch, + seq_len, + head_number, + head_size, + reinterpret_cast(qkptr), + reinterpret_cast(bias_qk_d), + false, + reinterpret_cast(tptr), + __float2half(static_cast(scale)), + __float2half(0.0)); + } else { + phi::funcs::MultiheadGPUComputeFunctor multihead_compute_func; + multihead_compute_func(dev_ctx, + batch, + seq_len, + head_number, + head_size, + qkptr, + bias_qk_d, + false, + tptr, + scale, + T(0.0)); + } + + int grid = batch * head_number * seq_len; + int block = head_size; + transpose<<>>( + tptr, output_d, batch, seq_len, head_number, head_size); +} + +} // namespace fusion +} // namespace phi + +#if defined(PADDLE_WITH_CUDA) +PD_REGISTER_PLUGIN_KERNEL(multihead_matmul, + metax_gpu, + ALL_LAYOUT, + phi::fusion::MultiheadMatmulKernel, + float, + phi::float16) {} +#else +PD_REGISTER_PLUGIN_KERNEL(multihead_matmul, + metax_gpu, + ALL_LAYOUT, + phi::fusion::MultiheadMatmulKernel, + float) {} +#endif diff --git a/backends/metax_gpu/kernels/funcs/generator.cc b/backends/metax_gpu/kernels/funcs/generator.cc new file mode 100644 index 00000000000..8fcbf474b07 --- /dev/null +++ b/backends/metax_gpu/kernels/funcs/generator.cc @@ -0,0 +1,287 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/core/generator.h" + +#include + +#include +#include +#include + +#include "paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/phi/backends/xpu/xpu_info.h" +#include "paddle/phi/core/enforce.h" + +static uint64_t GetRandomSeed() { + std::random_device rd; + // double has 53 bit significant, so limit uint64 to 53 bits + return ((((uint64_t)rd()) << 32) + rd()) & 0x1FFFFFFFFFFFFF; +} + +namespace phi { + +const std::shared_ptr& DefaultXPUGenerator(int64_t device_id) { +#if defined(PADDLE_WITH_XPU) + + static int64_t num_xpu_devices = -1; + static std::once_flag num_devices_init_flag; + static std::deque xpu_device_flags; + static std::vector> default_xpu_generators; + + std::call_once(num_devices_init_flag, []() { + num_xpu_devices = phi::backends::xpu::GetXPUDeviceCount(); + xpu_device_flags.resize(num_xpu_devices); + default_xpu_generators.resize(num_xpu_devices); + }); + if (device_id < 0) { + PADDLE_THROW(common::errors::InvalidArgument( + "xpu device id should be greater than 0")); + } + + std::call_once(xpu_device_flags[device_id], [device_id]() { + default_xpu_generators[device_id] = + std::make_shared(GetRandomSeed(), device_id); + VLOG(4) << "initial seed: " + << default_xpu_generators[device_id]->GetCurrentSeed(); + }); + return default_xpu_generators[device_id]; +#else + PADDLE_THROW(common::errors::PermissionDenied( + "getDefaultXPUGenerator only support in XPU place")); +#endif +} + +const std::shared_ptr& DefaultCUDAGenerator(int64_t device_id) { +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + + static int64_t num_cuda_devices = -1; + static std::once_flag num_devices_init_flag; + static std::deque cuda_device_flags; + static std::vector> default_cuda_generators; + + std::call_once(num_devices_init_flag, []() { + num_cuda_devices = phi::backends::gpu::GetGPUDeviceCount(); + cuda_device_flags.resize(num_cuda_devices); + default_cuda_generators.resize(num_cuda_devices); + }); + if (device_id < 0) { + PADDLE_THROW(common::errors::InvalidArgument( + "cuda device id should be greater than 0")); + } + + std::call_once(cuda_device_flags[device_id], [device_id]() { + default_cuda_generators[device_id] = + std::make_shared(GetRandomSeed(), device_id); + VLOG(7) << "initial seed: " + << default_cuda_generators[device_id]->GetCurrentSeed(); + }); + return default_cuda_generators[device_id]; +#else + PADDLE_THROW(common::errors::PermissionDenied( + "getDefaultCUDAGenerator only support in CUDA place")); +#endif +} + +const std::shared_ptr& DefaultCPUGenerator() { + static auto default_cpu_generator = + std::make_shared(GetRandomSeed()); + return default_cpu_generator; +} + +const std::shared_ptr& DefaultCustomDeviceGenerator( + const phi::CustomPlace& place) { + static std:: + unordered_map, phi::Place::Hash> + generators; + if (generators.find(place) == generators.end()) { + generators.insert({place, std::make_shared(GetRandomSeed())}); + } + return generators[place]; +} + +using RNGMap = std::unordered_map>; + +static RNGMap& GetRandomSeedGeneratorMap() { + static auto random_seed_generator_map = RNGMap(); + return random_seed_generator_map; +} + +const std::shared_ptr& SetRandomSeedGenerator( + const std::string& name, uint64_t seed) { + auto& rng_map = GetRandomSeedGeneratorMap(); + auto iter = rng_map.find(name); + PADDLE_ENFORCE_EQ(iter == rng_map.end(), + true, + common::errors::AlreadyExists( + "%s RandomSeedGenerator is already exist", 
name)); + + auto generator = std::make_shared(seed); + bool emplace_success = rng_map.emplace(name, generator).second; + PADDLE_ENFORCE_EQ( + emplace_success, + true, + common::errors::PermissionDenied( + "SetRandomSeedGenerator cannot emplace %s RandomSeedGenerator", + name)); + return rng_map[name]; +} + +const std::shared_ptr& GetRandomSeedGenerator( + const std::string& name) { + auto& rng_map = GetRandomSeedGeneratorMap(); + auto iter = rng_map.find(name); + PADDLE_ENFORCE_EQ(iter != rng_map.end(), + true, + common::errors::NotFound( + "%s RandomSeedGenerator is not found, please " + "use `set_random_seed_generator` to set rng first", + name)); + return iter->second; +} + +// There are 3 conditions: +// (1) op seed is set, use op seed. +// (2) op seed is not set, global seed is set, use global seed. +// (3) op seed is not set, global seed is not set too, use random seed from +// RandomGenerator. +std::shared_ptr GetCPURandomEngine(uint64_t seed) { + if (seed == 0) { + VLOG(4) << "Use random cpu_engine from generator"; + return DefaultCPUGenerator()->GetCPUEngine(); + } else { + // NOTE(zhiqiu): creating an cpu_engine instance everytime instead of using + // OpDefaultCPUEngine(), this is the legacy behavior of random operators. + // The benefit is that when running PE with fixed-seed in multiple threads, + // each thread has their own cpu_engine, and doesn't affect each other. + // + // And we need to measure the determinacy of Generator in PE. + auto cpu_engine = std::make_shared(); + static std::mutex mu_; + { + std::lock_guard lock(mu_); + cpu_engine->seed(seed); + } + return cpu_engine; + } +} + +inline void Generator::print_state_info() { + VLOG(7) << "Generator Random state " + << "device id: " << state().device << ", seed: " << state().seed + << ", offset: " << state().offset << ", cpu_engine: " << cpu_engine(); +} + +Generator::Generator() { + auto seed = GetRandomSeed(); + current_index = states_.size(); + states_.emplace_back(-1, seed); + print_state_info(); +} + +Generator::Generator(uint64_t seed) { + current_index = states_.size(); + states_.emplace_back(-1, seed); + print_state_info(); +} + +Generator::Generator(uint64_t seed, int64_t device_id) { + current_index = states_.size(); + // device id first, then seed + states_.emplace_back(device_id, seed); + print_state_info(); +} + +phi::Generator::GeneratorState Generator::GetState() { return state(); } + +void Generator::SetState(const phi::Generator::GeneratorState& state) { + std::lock_guard lock(mu_); + if (current_index < states_.size()) + states_[current_index] = state; + else + PADDLE_THROW(common::errors::NotFound("Generator index is not found")); + print_state_info(); +} + +uint64_t Generator::GetStateIndex() { return current_index; } + +void Generator::SetStateIndex(uint64_t StateIndex) { + std::lock_guard lock(mu_); + if (current_index < states_.size()) + current_index = StateIndex; + else + PADDLE_THROW(common::errors::NotFound("Generator index is not found")); +} + +uint64_t Generator::RegisterStateIndex(const GeneratorState& state) { + std::lock_guard lock(mu_); + auto new_index = states_.size(); + states_.push_back(state); + current_index = new_index; + return new_index; +} + +inline Generator::GeneratorState& Generator::state() { + if (current_index < states_.size()) + return states_[current_index]; + else + PADDLE_THROW(common::errors::NotFound("Generator index is not found")); +} + +inline std::shared_ptr Generator::cpu_engine() { + return state().cpu_engine; +} + +uint64_t Generator::GetCurrentSeed() { + 
std::lock_guard lock(mu_); + return state().seed; +} + +uint64_t Generator::Seed() { + std::lock_guard lock(mu_); + uint64_t seed = GetRandomSeed(); + state().reset(seed); + return seed; +} + +void Generator::SetCurrentSeed(uint64_t seed) { + std::lock_guard lock(mu_); + state().reset(seed); +} + +std::shared_ptr Generator::GetCPUEngine() { + return cpu_engine(); +} + +uint64_t Generator::Random64() { + std::lock_guard lock(mu_); + auto current_engine = cpu_engine(); + return (*current_engine)(); +} + +std::pair Generator::IncrementOffset(uint64_t increment) { +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \ + defined(PADDLE_WITH_CUSTOM_DEVICE) || defined(PADDLE_WITH_XPU) + std::lock_guard lock(mu_); + uint64_t offset = state().offset; + state().offset = offset + increment; + print_state_info(); + return std::make_pair(state().seed, offset); +#else + PADDLE_THROW(common::errors::PermissionDenied( + "Increment Offset only support in CUDA place")); +#endif +} + +} // namespace phi diff --git a/backends/metax_gpu/kernels/metax_kernel/cudnn_lstm_grad_kernel.cu b/backends/metax_gpu/kernels/metax_kernel/cudnn_lstm_grad_kernel.cu new file mode 100644 index 00000000000..766d984a25b --- /dev/null +++ b/backends/metax_gpu/kernels/metax_kernel/cudnn_lstm_grad_kernel.cu @@ -0,0 +1,362 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "kernels/metax_kernel/metax_context.h" //NOLINT +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/cudnn_lstm_grad_kernel.h" +#include "paddle/phi/kernels/gpu/cudnn_lstm_utils.h" + +namespace phi { + +template +void CudnnLSTMGradKernel( + const Context &dev_ctx, + const DenseTensor &x, + const DenseTensor &init_h, + const DenseTensor &init_c, + const paddle::optional> &weight_list, + const paddle::optional &sequence_length, + const DenseTensor &out, + const DenseTensor &reserve, + const DenseTensor &state_out, + const DenseTensor &out_grad, + const DenseTensor &last_h_grad, + const DenseTensor &last_c_grad, + float dropout_prob, + bool is_bidirec, + int hidden_size, + int num_layers, + bool is_test, + int seed, + DenseTensor *x_grad, + DenseTensor *init_h_grad, + DenseTensor *init_c_grad, + std::vector weight_grad_list) { + auto input_dims = x.dims(); + auto init_h_dims = init_h.dims(); + auto init_c_dims = init_c.dims(); + + auto *init_h_data = init_h.data(); + auto *init_c_data = init_c.data(); + auto *out_data = out.data(); + auto *out_grad_data = out_grad.data(); + auto *last_h_grad_data = last_h_grad.data(); + auto *last_c_grad_data = last_c_grad.data(); + + auto running_weight_list = *weight_list.get_ptr(); + int weight_numel = size_sum(running_weight_list); + bool continuous = is_continuous>( + running_weight_list); + + // auto handle = dev_ctx.cudnn_handle(); + auto handle = GetDnnHandle(dev_ctx.stream(), dev_ctx.GetPlace()); + auto place = dev_ctx.GetPlace(); + auto stream = dev_ctx.stream(); + phi::DenseTensor weight_whole; + T *weight_data = nullptr; + + if (!continuous) { + weight_whole.Resize({weight_numel}); + dev_ctx.template Alloc(&weight_whole); + weight_to_tensor(place, stream, running_weight_list, &weight_whole); + weight_data = weight_whole.data(); + } else { + weight_data = const_cast(running_weight_list[0]->data()); + } + + phi::DenseTensor weight_grad; + phi::funcs::SetConstant zero; + weight_grad.Resize({weight_numel}); + dev_ctx.template Alloc(&weight_grad); + zero(dev_ctx, &weight_grad, static_cast(0.0)); + T *weight_grad_data = weight_grad.data(); + + int offset = 0; + for (size_t i = 0; i < weight_grad_list.size(); ++i) { + size_t len = weight_grad_list[i]->numel(); + auto dim = weight_grad_list[i]->dims(); + weight_grad_list[i] + ->ShareDataWith(weight_grad.Slice(static_cast(offset), + static_cast(offset + len))) + .Resize(dim); + offset += len; + } + + x_grad->Resize(input_dims); + dev_ctx.template Alloc(x_grad); + auto *in_grad_data = x_grad->data(); + + if (init_h_grad) { + init_h_grad->Resize(init_h_dims); + dev_ctx.template Alloc(init_h_grad); + } + auto *init_h_grad_data = init_h_grad ? init_h_grad->data() : nullptr; + + if (init_c_grad) { + init_c_grad->Resize(init_c_dims); + dev_ctx.template Alloc(init_c_grad); + } + auto *init_c_grad_data = init_c_grad ? 
init_c_grad->data() : nullptr; + + auto running_seq_length = sequence_length.get_ptr(); + bool has_seq_length = running_seq_length != nullptr; + std::vector SequenceLength; + if (has_seq_length) { + SequenceLength = phi::GetVectorFromTensor(running_seq_length); + } + + int seq_length = input_dims[0]; + int batch_size = x.dims()[1]; + int input_size = x.dims()[2]; + + size_t workspace_size; + size_t reserve_size; + + ScopedRNNBase rnn(seq_length, + batch_size, + input_size, + hidden_size, + num_layers, + dropout_prob, + seed, + weight_numel, + true, + is_bidirec); + + rnn.Create(handle, + dev_ctx.GetPlace(), + SequenceLength, + &workspace_size, + &reserve_size, + const_cast(&state_out)); + + phi::DenseTensor workspace_data_; + workspace_data_.Resize({static_cast(workspace_size)}); + dev_ctx.template Alloc(&workspace_data_); + const uint8_t *reserve_data = reserve.data(); + +#if CUDNN_VERSION >= 90000 + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnRNNBackwardData_v8( + handle, + rnn.rnn_desc(), + nullptr, + rnn.y_seq_desc(), + out_data, + out_grad_data, + rnn.x_seq_desc(), + in_grad_data, + rnn.init_h_desc(), + init_h_data, + last_h_grad_data, + init_h_grad_data, + rnn.init_c_desc(), + init_c_data, + last_c_grad_data, + init_c_grad_data, + rnn.weights_size(), + weight_data, + workspace_size, + workspace_data_.data(), + reserve_size, + const_cast(reserve_data))); + + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnRNNBackwardWeights_v8( + handle, + rnn.rnn_desc(), + CUDNN_WGRAD_MODE_ADD, + nullptr, + rnn.x_seq_desc(), + x.data(), + rnn.init_h_desc(), + init_h.data(), + rnn.y_seq_desc(), + out.data(), + rnn.weights_size(), + weight_grad_data, + workspace_size, + workspace_data_.data(), + reserve_size, + const_cast(reserve_data))); +#else + + if (!has_seq_length) { +// This interface is used when the input/output is unpadded. 
+#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::miopenRNNBackwardData(handle, + rnn.rnn_desc(), + seq_length, + rnn.y_descs(), + out_data, + rnn.y_descs(), + out_grad_data, + rnn.last_h_desc(), + last_h_grad_data, + rnn.last_c_desc(), + last_c_grad_data, + rnn.weight_desc(), + weight_data, + rnn.init_h_desc(), + init_h_data, + rnn.init_c_desc(), + init_c_data, + rnn.x_descs(), + in_grad_data, + rnn.init_h_desc(), + init_h_grad_data, + rnn.init_c_desc(), + init_c_grad_data, + workspace_data_.data(), + workspace_size, + const_cast(reserve_data), + reserve_size)); + + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::miopenRNNBackwardWeights( + handle, + rnn.rnn_desc(), + seq_length, + rnn.x_descs(), + x.data(), + rnn.init_h_desc(), + init_h.data(), + rnn.y_descs(), + out.data(), + rnn.weight_desc(), + weight_grad_data, + workspace_data_.data(), + workspace_size, + const_cast(reserve_data), + reserve_size)); +#else + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnRNNBackwardData(handle, + rnn.rnn_desc(), + seq_length, + rnn.y_descs(), + out_data, + rnn.y_descs(), + out_grad_data, + rnn.last_h_desc(), + last_h_grad_data, + rnn.last_c_desc(), + last_c_grad_data, + rnn.weight_desc(), + weight_data, + rnn.init_h_desc(), + init_h_data, + rnn.init_c_desc(), + init_c_data, + rnn.x_descs(), + in_grad_data, + rnn.init_h_desc(), + init_h_grad_data, + rnn.init_c_desc(), + init_c_grad_data, + workspace_data_.data(), + workspace_size, + const_cast(reserve_data), + reserve_size)); + + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnRNNBackwardWeights( + handle, + rnn.rnn_desc(), + seq_length, + rnn.x_descs(), + x.data(), + rnn.init_h_desc(), + init_h.data(), + rnn.y_descs(), + out.data(), + workspace_data_.data(), + workspace_size, + rnn.weight_desc(), + weight_grad_data, + const_cast(reserve_data), + reserve_size)); +#endif + } else { +#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201 + // for train + // This interface is used when the input/output is padded. 
+ PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnRNNBackwardDataEx( + handle, + rnn.rnn_desc(), + rnn.y_seq_desc(), + out_data, + rnn.y_seq_desc(), + out_grad_data, + nullptr, + nullptr, + rnn.last_h_desc(), + last_h_grad_data, + rnn.last_c_desc(), + last_c_grad_data, + rnn.weight_desc(), + weight_data, + rnn.init_h_desc(), + init_h_data, + rnn.init_c_desc(), + init_c_data, + rnn.x_seq_desc(), + in_grad_data, + rnn.init_h_desc(), + init_h_grad_data, + rnn.init_c_desc(), + init_c_grad_data, + nullptr, + nullptr, + workspace_data_.data(), + workspace_size, + const_cast(reserve_data), + reserve_size)); + + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnRNNBackwardWeightsEx( + handle, + rnn.rnn_desc(), + rnn.x_seq_desc(), + x.data(), + rnn.init_h_desc(), + init_h.data(), + rnn.y_seq_desc(), + out.data(), + workspace_data_.data(), + workspace_size, + rnn.weight_desc(), + weight_grad_data, + const_cast(reserve_data), + reserve_size)); +#else + PADDLE_THROW(common::errors::Unavailable( + "The padded input of rnn is supported by cudnnRNNBackwardDataEx, " + "cudnnRNNBackwardWeightsEx, but it only works when the version " + "of cudnn is larger than 7.2.1")); +#endif + } + +#endif // end CUDNN_VERSION >= 90000 +} + +} // namespace phi + +#ifdef PADDLE_WITH_HIP +PD_REGISTER_KERNEL( + cudnn_lstm_grad, GPU, ALL_LAYOUT, phi::CudnnLSTMGradKernel, float) {} +#else +PD_REGISTER_PLUGIN_KERNEL(cudnn_lstm_grad, + metax_gpu, + ALL_LAYOUT, + phi::CudnnLSTMGradKernel, + float, + double) {} +#endif diff --git a/backends/metax_gpu/kernels/metax_kernel/cudnn_lstm_kernel.cu b/backends/metax_gpu/kernels/metax_kernel/cudnn_lstm_kernel.cu new file mode 100644 index 00000000000..6bb94c9281a --- /dev/null +++ b/backends/metax_gpu/kernels/metax_kernel/cudnn_lstm_kernel.cu @@ -0,0 +1,428 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "glog/logging.h" +#include "kernels/metax_kernel/metax_context.h" //NOLINT +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/cudnn_lstm_kernel.h" +#include "paddle/phi/kernels/gpu/cudnn_lstm_utils.h" + +namespace phi { + +template +#ifdef PADDLE_WITH_HIP +void LSTMInference(const bool &has_seq_length, + const miopenHandle_t &handle, +#else +void LSTMInference(const bool &has_seq_length, + const cudnnHandle_t &handle, +#endif + const int &seq_length, + ScopedRNNBase *rnn, + const T *x_data, + const T *init_h_data, + const T *init_c_data, + const T *w_data, + T *out_data, + T *last_h_data, + T *last_c_data, + phi::DenseTensor *workspace_data, + const size_t &workspace_size) { +#if CUDNN_VERSION >= 90000 + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnRNNForward(handle, + rnn->rnn_desc(), + CUDNN_FWD_MODE_INFERENCE, + nullptr, + rnn->x_seq_desc(), + x_data, + rnn->y_seq_desc(), + out_data, + rnn->init_h_desc(), + init_h_data, + last_h_data, + rnn->init_c_desc(), + init_c_data, + last_c_data, + rnn->weights_size(), + w_data, + workspace_size, + workspace_data->data(), + 0, + nullptr)); + +#else + + if (!has_seq_length) { +// for inference +// This interface is used when the input/output is unpadded. +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::miopenRNNForwardInference(handle, + rnn->rnn_desc(), + seq_length, + rnn->x_descs(), + x_data, + rnn->init_h_desc(), + init_h_data, + rnn->init_c_desc(), + init_c_data, + rnn->weight_desc(), + w_data, + rnn->y_descs(), + out_data, + rnn->last_h_desc(), + last_h_data, + rnn->last_c_desc(), + last_c_data, + workspace_data->data(), + workspace_size)); +#else + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnRNNForwardInference(handle, + rnn->rnn_desc(), + seq_length, + rnn->x_descs(), + x_data, + rnn->init_h_desc(), + init_h_data, + rnn->init_c_desc(), + init_c_data, + rnn->weight_desc(), + w_data, + rnn->y_descs(), + out_data, + rnn->last_h_desc(), + last_h_data, + rnn->last_c_desc(), + last_c_data, + workspace_data->data(), + workspace_size)); +#endif + } else { +#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201 + // for inference + // This interface is used when the input/output is padded. 
+ PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnRNNForwardInferenceEx( + handle, + rnn->rnn_desc(), + rnn->x_seq_desc(), + x_data, + rnn->init_h_desc(), + init_h_data, + rnn->init_c_desc(), + init_c_data, + rnn->weight_desc(), + w_data, + rnn->y_seq_desc(), + out_data, + rnn->last_h_desc(), + last_h_data, + rnn->last_c_desc(), + last_c_data, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + workspace_data->data(), + workspace_size)); +#else + // CUDNN VERSION has to >=7.2.1 + PADDLE_THROW(common::errors::Unavailable( + "The padded input is supported by " + "cudnnRNNForwardInferenceEx, but it only works when " + "the version of cudnn is larger than 7.2.1")); +#endif + } + +#endif // end CUDNN_VERSION >= 90000 +} + +template +void CudnnLSTMKernel( + const Context &dev_ctx, + const DenseTensor &x, + const DenseTensor &init_h, + const DenseTensor &init_c, + const paddle::optional &w, + const paddle::optional> &weight_list, + const paddle::optional &sequence_length, + float dropout_prob, + bool is_bidirec, + int hidden_size, + int num_layers, + bool is_test, + int seed, + DenseTensor *out, + DenseTensor *last_h, + DenseTensor *last_c, + DenseTensor *reserve, + DenseTensor *state_out) { + const T *x_data = x.data(); + const T *init_h_data = init_h.data(); + const T *init_c_data = init_c.data(); + + T *out_data = dev_ctx.template Alloc(out); + T *last_h_data = dev_ctx.template Alloc(last_h); + T *last_c_data = dev_ctx.template Alloc(last_c); + + if (!is_test) { + if (seed == 0) { + // If not specify seed, use global Generator to generate seed. + int device_id = dev_ctx.GetPlace().GetDeviceId(); + auto gen_cuda = phi::DefaultCUDAGenerator(device_id); + seed = static_cast(gen_cuda->Random64()); + } + } + + auto *running_sequence_length = sequence_length.get_ptr(); + bool has_seq_length = running_sequence_length != nullptr; + std::vector SequenceLength; + if (has_seq_length) { + SequenceLength = phi::GetVectorFromTensor(running_sequence_length); + } + + // auto handle = dev_ctx.cudnn_handle(); + auto handle = GetDnnHandle(dev_ctx.stream(), dev_ctx.GetPlace()); + + int seq_length = x.dims()[0]; + int batch_size = x.dims()[1]; + int input_size = x.dims()[2]; + bool state_initialized = state_out->initialized() ? true : false; + + size_t workspace_size; + size_t reserve_size; + phi::DenseTensor weight_whole; + T *w_data = nullptr; + int weight_numel; + bool w_initialized = false; + auto place = dev_ctx.GetPlace(); + auto stream = dev_ctx.stream(); + auto *running_w = w.get_ptr(); + if (is_test && running_w != nullptr) { + w_initialized = running_w->initialized() ? true : false; + weight_numel = running_w->numel(); + } + if (!w_initialized) { + auto running_weight_list = *weight_list.get_ptr(); + bool continuous = is_continuous>( + running_weight_list); + weight_numel = size_sum(running_weight_list); + + if (!continuous) { + LOG_FIRST_N(WARNING, 2) + << "If the memory space of the Input WeightList is not continuous, " + "less efficient calculation will be called. 
Please call " + "flatten_parameters() to make the input memory continuous."; + weight_whole.Resize({weight_numel}); + dev_ctx.template Alloc(&weight_whole); + weight_to_tensor(place, stream, running_weight_list, &weight_whole); + w_data = weight_whole.data(); + if (is_test) { // maybe also reset small weights' ptr for training + int offset = 0; + for (size_t i = 0; i < running_weight_list.size(); ++i) { + size_t len = running_weight_list[i]->numel(); + auto dim = running_weight_list[i]->dims(); + const_cast(running_weight_list[i]) + ->ShareDataWith( + weight_whole.Slice(static_cast(offset), + static_cast(offset + len))) + .Resize(dim); + offset += len; + } + } + } else { + w_data = const_cast(running_weight_list[0]->data()); + } + } else { + w_data = const_cast(running_w->data()); + } + + ScopedRNNBase rnn(seq_length, + batch_size, + input_size, + hidden_size, + num_layers, + dropout_prob, + seed, + weight_numel, + state_initialized, + is_bidirec); + rnn.Create(handle, + dev_ctx.GetPlace(), + SequenceLength, + &workspace_size, + &reserve_size, + state_out); + + phi::DenseTensor workspace_data_; + workspace_data_.Resize({static_cast(workspace_size)}); + dev_ctx.template Alloc(&workspace_data_); + + reserve->Resize({static_cast(reserve_size)}); + auto *reserve_data = dev_ctx.template Alloc(reserve); + + if (is_test) { + LSTMInference(has_seq_length, + handle, + seq_length, + &rnn, + x_data, + init_h_data, + init_c_data, + w_data, + out_data, + last_h_data, + last_c_data, + &workspace_data_, + workspace_size); + } else { +#if CUDNN_VERSION >= 90000 + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnRNNForward(handle, + rnn.rnn_desc(), + CUDNN_FWD_MODE_TRAINING, + nullptr, + rnn.x_seq_desc(), + x_data, + rnn.y_seq_desc(), + out_data, + rnn.init_h_desc(), + init_h_data, + last_h_data, + rnn.init_c_desc(), + init_c_data, + last_c_data, + rnn.weights_size(), + w_data, + workspace_size, + workspace_data_.data(), + reserve_size, + reserve_data)); +#else + + if (!has_seq_length) { +// for train +// This interface is used when the input/output is unpadded. +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::miopenRNNForwardTraining( + handle, + rnn.rnn_desc(), + seq_length, + rnn.x_descs(), + x_data, + rnn.init_h_desc(), + init_h_data, + rnn.init_c_desc(), + init_c_data, + rnn.weight_desc(), + w_data, + rnn.y_descs(), + out_data, + rnn.last_h_desc(), + last_h_data, + rnn.last_c_desc(), + last_c_data, + workspace_data_.data(), + workspace_size, + reserve_data, + reserve_size)); +#else + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnRNNForwardTraining(handle, + rnn.rnn_desc(), + seq_length, + rnn.x_descs(), + x_data, + rnn.init_h_desc(), + init_h_data, + rnn.init_c_desc(), + init_c_data, + rnn.weight_desc(), + w_data, + rnn.y_descs(), + out_data, + rnn.last_h_desc(), + last_h_data, + rnn.last_c_desc(), + last_c_data, + workspace_data_.data(), + workspace_size, + reserve_data, + reserve_size)); +#endif + } else { +#if !defined(PADDLE_WITH_HIP) && CUDNN_VERSION >= 7201 + // for train + // This interface is used when the input/output is padded. 
+ PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnRNNForwardTrainingEx( + handle, + rnn.rnn_desc(), + rnn.x_seq_desc(), + x_data, + rnn.init_h_desc(), + init_h_data, + rnn.init_c_desc(), + init_c_data, + rnn.weight_desc(), + w_data, + rnn.y_seq_desc(), + out_data, + rnn.last_h_desc(), + last_h_data, + rnn.last_c_desc(), + last_c_data, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + workspace_data_.data(), + workspace_size, + reserve_data, + reserve_size)); +#else + PADDLE_THROW(common::errors::Unavailable( + "The padded input is supported by " + "cudnnRNNForwardTrainingEx, but it only works when " + "the version of cudnn is larger than 7.2.1")); +#endif + } +#endif // end CUDNN_VERSION >= 90000 + } +} + +} // namespace phi + +#ifdef PADDLE_WITH_HIP +PD_REGISTER_KERNEL(cudnn_lstm, GPU, ALL_LAYOUT, phi::CudnnLSTMKernel, float) { + kernel->InputAt(5).SetDataType(phi::DataType::INT32); + kernel->OutputAt(3).SetDataType(phi::DataType::UINT8); + kernel->OutputAt(4).SetDataType(phi::DataType::UINT8); +} +#else +PD_REGISTER_PLUGIN_KERNEL( + cudnn_lstm, metax_gpu, ALL_LAYOUT, phi::CudnnLSTMKernel, float, double) { + kernel->InputAt(5).SetDataType(phi::DataType::INT32); + kernel->OutputAt(3).SetDataType(phi::DataType::UINT8); + kernel->OutputAt(4).SetDataType(phi::DataType::UINT8); +} +#endif diff --git a/backends/metax_gpu/tests/ignore.txt b/backends/metax_gpu/tests/ignore.txt index b4f1afbe5b0..4e54e17b3ef 100644 --- a/backends/metax_gpu/tests/ignore.txt +++ b/backends/metax_gpu/tests/ignore.txt @@ -19,3 +19,7 @@ test_uniform_random_op test_c_embedding_op test_slice_op test_compare_op +test_conv3d_transpose_op +test_conv3d_layer +test_conv3d_transpose_part2_op +test_fused_conv2d_add_act_op