| 
 | 1 | +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.  | 
 | 2 | +
  | 
 | 3 | +Licensed under the Apache License, Version 2.0 (the "License");  | 
 | 4 | +you may not use this file except in compliance with the License.  | 
 | 5 | +You may obtain a copy of the License at  | 
 | 6 | +
  | 
 | 7 | +    http://www.apache.org/licenses/LICENSE-2.0  | 
 | 8 | +
  | 
 | 9 | +Unless required by applicable law or agreed to in writing, software  | 
 | 10 | +distributed under the License is distributed on an "AS IS" BASIS,  | 
 | 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  | 
 | 12 | +See the License for the specific language governing permissions and  | 
 | 13 | +limitations under the License. */  | 
 | 14 | + | 
 | 15 | +#include "paddle/phi/kernels/cross_entropy_grad_kernel.h"  | 
 | 16 | + | 
 | 17 | +#ifdef __NVCC__  | 
 | 18 | +#include "cub/cub.cuh"  | 
 | 19 | +#endif  | 
 | 20 | +#ifdef __HIPCC__  | 
 | 21 | +#include <hipcub/hipcub.hpp>  | 
 | 22 | +namespace cub = hipcub;  | 
 | 23 | +#endif  | 
 | 24 | + | 
 | 25 | +#include "kernels/gpudnn/softmax_gpudnn.h"  | 
 | 26 | +#include "paddle/phi/backends/gpu/gpu_device_function.h"  | 
 | 27 | +#include "paddle/phi/backends/gpu/gpu_dnn.h"  | 
 | 28 | +#include "paddle/phi/common/amp_type_traits.h"  | 
 | 29 | +#include "paddle/phi/core/kernel_registry.h"  | 
 | 30 | +#include "paddle/phi/core/tensor_utils.h"  | 
 | 31 | +#include "paddle/phi/core/visit_type.h"  | 
 | 32 | +#include "paddle/phi/kernels/funcs/axis_utils.h"  | 
 | 33 | +#include "paddle/phi/kernels/funcs/for_range.h"  | 
 | 34 | +#include "paddle/phi/kernels/funcs/math_function.h"  | 
 | 35 | +#include "paddle/phi/kernels/funcs/softmax.h"  | 
 | 36 | + | 
 | 37 | +namespace phi {  | 
 | 38 | + | 
 | 39 | +/*  | 
 | 40 | +  Vectorized wrapper of softmax with cross entropy grad hard label.  | 
 | 41 | +  Optimized with float4 vectorization for memory coalescing and improved  | 
 | 42 | +  throughput.  | 
 | 43 | +*/  | 
 | 44 | +template <typename T, typename LabelT, typename LogitT>  | 
 | 45 | +__global__ void SoftmaxWithCrossEntropyGradHardLabelVectorized(  | 
 | 46 | +    LogitT* __restrict__ logits_grad,  | 
 | 47 | +    const T* __restrict__ loss_grad,  | 
 | 48 | +    const T* __restrict__ softmax,  | 
 | 49 | +    const LabelT* __restrict__ labels,  | 
 | 50 | +    const int64_t n,  | 
 | 51 | +    const int64_t dim,  | 
 | 52 | +    const int64_t d,  | 
 | 53 | +    const int ignore_index) {  | 
 | 54 | +  // Vectorized load/store with float4 for 128-bit memory transactions  | 
 | 55 | +  constexpr int VEC_SIZE = 4;  | 
 | 56 | +  using VecT = typename phi::AlignedVector<LogitT, VEC_SIZE>;  | 
 | 57 | +  using SoftmaxVecT = typename phi::AlignedVector<T, VEC_SIZE>;  | 
 | 58 | + | 
 | 59 | +  int64_t tid = blockIdx.x * blockDim.x + threadIdx.x;  | 
 | 60 | +  int64_t vec_id = tid * VEC_SIZE;  | 
 | 61 | + | 
 | 62 | +  // Ensure we don't exceed bounds  | 
 | 63 | +  if (vec_id >= n * dim * d) return;  | 
 | 64 | + | 
 | 65 | +  // Compute indices for vectorized access  | 
 | 66 | +  int64_t idx_n = vec_id / (d * dim);  | 
 | 67 | +  int64_t idx_dim_start = (vec_id / d) % dim;  | 
 | 68 | +  int64_t idx_d = vec_id % d;  | 
 | 69 | +  int64_t ids = idx_n * d + idx_d;  | 
 | 70 | + | 
 | 71 | +  // Load label once per thread  | 
 | 72 | +  auto lbl = static_cast<int64_t>(labels[ids]);  | 
 | 73 | + | 
 | 74 | +  if (lbl == ignore_index) {  | 
 | 75 | +    // Vectorized zero fill for ignore_index  | 
 | 76 | +    VecT* vec_grad = reinterpret_cast<VecT*>(&logits_grad[vec_id]);  | 
 | 77 | +    VecT zero_vec;  | 
 | 78 | +#pragma unroll  | 
 | 79 | +    for (int i = 0; i < VEC_SIZE; ++i) {  | 
 | 80 | +      zero_vec.val[i] = static_cast<LogitT>(0.0f);  | 
 | 81 | +    }  | 
 | 82 | +    *vec_grad = zero_vec;  | 
 | 83 | +    return;  | 
 | 84 | +  }  | 
 | 85 | + | 
 | 86 | +  // Vectorized load of softmax values  | 
 | 87 | +  SoftmaxVecT softmax_vec;  | 
 | 88 | +  const SoftmaxVecT* softmax_ptr =  | 
 | 89 | +      reinterpret_cast<const SoftmaxVecT*>(&softmax[vec_id]);  | 
 | 90 | +  softmax_vec = *softmax_ptr;  | 
 | 91 | + | 
 | 92 | +  // Load loss gradient (broadcast across vector elements)  | 
 | 93 | +  T loss_grad_val = loss_grad[ids];  | 
 | 94 | + | 
 | 95 | +  // Vectorized computation  | 
 | 96 | +  VecT grad_vec;  | 
 | 97 | +#pragma unroll  | 
 | 98 | +  for (int i = 0; i < VEC_SIZE; ++i) {  | 
 | 99 | +    int64_t current_dim = idx_dim_start + i;  | 
 | 100 | +    if (current_dim < dim) {  // Bounds check for partial vectors  | 
 | 101 | +      float softmax_val = static_cast<float>(softmax_vec.val[i]);  | 
 | 102 | +      float grad_val;  | 
 | 103 | + | 
 | 104 | +      if (lbl == current_dim) {  | 
 | 105 | +        grad_val = (softmax_val - 1.0f) * static_cast<float>(loss_grad_val);  | 
 | 106 | +      } else {  | 
 | 107 | +        grad_val = softmax_val * static_cast<float>(loss_grad_val);  | 
 | 108 | +      }  | 
 | 109 | + | 
 | 110 | +      grad_vec.val[i] = static_cast<LogitT>(grad_val);  | 
 | 111 | +    } else {  | 
 | 112 | +      grad_vec.val[i] = static_cast<LogitT>(0.0f);  | 
 | 113 | +    }  | 
 | 114 | +  }  | 
 | 115 | + | 
 | 116 | +  // Vectorized store  | 
 | 117 | +  VecT* grad_ptr = reinterpret_cast<VecT*>(&logits_grad[vec_id]);  | 
 | 118 | +  *grad_ptr = grad_vec;  | 
 | 119 | +}  | 
 | 120 | + | 
 | 121 | +/*  | 
 | 122 | +  Specialized kernel for dimensions not divisible by vector size  | 
 | 123 | +  Uses warp-level primitives for better performance on irregular sizes  | 
 | 124 | +*/  | 
 | 125 | +template <typename T, typename LabelT, typename LogitT>  | 
 | 126 | +__global__ void SoftmaxWithCrossEntropyGradHardLabelWarp(  | 
 | 127 | +    LogitT* __restrict__ logits_grad,  | 
 | 128 | +    const T* __restrict__ loss_grad,  | 
 | 129 | +    const T* __restrict__ softmax,  | 
 | 130 | +    const LabelT* __restrict__ labels,  | 
 | 131 | +    const int64_t n,  | 
 | 132 | +    const int64_t dim,  | 
 | 133 | +    const int64_t d,  | 
 | 134 | +    const int ignore_index) {  | 
 | 135 | +  const int warps_per_block = 4;  | 
 | 136 | +  const int threads_per_warp = 32;  | 
 | 137 | +  const int threads_per_block = warps_per_block * threads_per_warp;  | 
 | 138 | + | 
 | 139 | +  int tid = blockIdx.x * threads_per_block + threadIdx.x;  | 
 | 140 | +  int warp_id = threadIdx.x / threads_per_warp;  | 
 | 141 | +  int lane_id = threadIdx.x % threads_per_warp;  | 
 | 142 | + | 
 | 143 | +  // Process multiple elements per thread using warp-level parallelism  | 
 | 144 | +  int64_t elements_per_thread =  | 
 | 145 | +      (n * dim * d + gridDim.x * threads_per_block - 1) /  | 
 | 146 | +      (gridDim.x * threads_per_block);  | 
 | 147 | + | 
 | 148 | +  for (int e = 0; e < elements_per_thread; ++e) {  | 
 | 149 | +    int64_t idx = tid + e * gridDim.x * threads_per_block;  | 
 | 150 | +    if (idx >= n * dim * d) break;  | 
 | 151 | + | 
 | 152 | +    int64_t idx_n = idx / (d * dim);  | 
 | 153 | +    int64_t idx_dim = (idx / d) % dim;  | 
 | 154 | +    int64_t idx_d = idx % d;  | 
 | 155 | +    int64_t ids = idx_n * d + idx_d;  | 
 | 156 | + | 
 | 157 | +    auto lbl = static_cast<int64_t>(labels[ids]);  | 
 | 158 | + | 
 | 159 | +    if (lbl == ignore_index) {  | 
 | 160 | +      logits_grad[idx] = static_cast<LogitT>(0.0f);  | 
 | 161 | +    } else if (lbl == idx_dim) {  | 
 | 162 | +      logits_grad[idx] =  | 
 | 163 | +          static_cast<LogitT>((static_cast<float>(softmax[idx]) - 1.0f) *  | 
 | 164 | +                              static_cast<float>(loss_grad[ids]));  | 
 | 165 | +    } else {  | 
 | 166 | +      logits_grad[idx] =  | 
 | 167 | +          static_cast<LogitT>(static_cast<float>(softmax[idx]) *  | 
 | 168 | +                              static_cast<float>(loss_grad[ids]));  | 
 | 169 | +    }  | 
 | 170 | +  }  | 
 | 171 | +}  | 
 | 172 | + | 
 | 173 | +/*  | 
 | 174 | +  Optimized kernel selector based on problem size and alignment  | 
 | 175 | +*/  | 
 | 176 | +template <typename T, typename LabelT, typename LogitT>  | 
 | 177 | +void LaunchOptimizedCrossEntropyGradKernel(const GPUContext& dev_ctx,  | 
 | 178 | +                                           LogitT* logits_grad,  | 
 | 179 | +                                           const T* loss_grad,  | 
 | 180 | +                                           const T* softmax,  | 
 | 181 | +                                           const LabelT* labels,  | 
 | 182 | +                                           const int64_t n,  | 
 | 183 | +                                           const int64_t dim,  | 
 | 184 | +                                           const int64_t d,  | 
 | 185 | +                                           const int ignore_index) {  | 
 | 186 | +  const int64_t total_elements = n * dim * d;  | 
 | 187 | +  auto stream = dev_ctx.stream();  | 
 | 188 | + | 
 | 189 | +  // Check alignment for vectorized kernel  | 
 | 190 | +  bool is_aligned = (reinterpret_cast<uintptr_t>(logits_grad) % 16 == 0) &&  | 
 | 191 | +                    (reinterpret_cast<uintptr_t>(softmax) % 16 == 0) &&  | 
 | 192 | +                    (total_elements % 4 == 0);  | 
 | 193 | + | 
 | 194 | +  if (is_aligned && total_elements >= 1024) {  | 
 | 195 | +    // Use vectorized kernel for aligned, large problems  | 
 | 196 | +    constexpr int VEC_SIZE = 4;  | 
 | 197 | +    const int threads_per_block = 256;  | 
 | 198 | +    const int vec_elements = total_elements / VEC_SIZE;  | 
 | 199 | +    const int blocks =  | 
 | 200 | +        (vec_elements + threads_per_block - 1) / threads_per_block;  | 
 | 201 | + | 
 | 202 | +    SoftmaxWithCrossEntropyGradHardLabelVectorized<T, LabelT, LogitT>  | 
 | 203 | +        <<<blocks, threads_per_block, 0, stream>>>(  | 
 | 204 | +            logits_grad, loss_grad, softmax, labels, n, dim, d, ignore_index);  | 
 | 205 | +  } else {  | 
 | 206 | +    // Use warp-specialized kernel for irregular sizes  | 
 | 207 | +    const int warps_per_block = 4;  | 
 | 208 | +    const int threads_per_block = warps_per_block * 32;  | 
 | 209 | +    const int blocks =  | 
 | 210 | +        std::min(1024,  | 
 | 211 | +                 static_cast<int>((total_elements + threads_per_block - 1) /  | 
 | 212 | +                                  threads_per_block));  | 
 | 213 | + | 
 | 214 | +    SoftmaxWithCrossEntropyGradHardLabelWarp<T, LabelT, LogitT>  | 
 | 215 | +        <<<blocks, threads_per_block, 0, stream>>>(  | 
 | 216 | +            logits_grad, loss_grad, softmax, labels, n, dim, d, ignore_index);  | 
 | 217 | +  }  | 
 | 218 | +}  | 
 | 219 | + | 
 | 220 | +template <typename T, typename LabelT>  | 
 | 221 | +void CrossEntropyWithSoftmaxBwdWithDowncastGPUKernel(  | 
 | 222 | +    const GPUContext& dev_ctx,  | 
 | 223 | +    const DenseTensor& label,  | 
 | 224 | +    const DenseTensor& softmax,  | 
 | 225 | +    const DenseTensor& loss_grad,  | 
 | 226 | +    int axis,  | 
 | 227 | +    DenseTensor* logits_grad) {  | 
 | 228 | +  //   PADDLE_ENFORCE_EQ(  | 
 | 229 | +  //       dev_ctx.GetPlace().GetType(),  | 
 | 230 | +  //       phi::AllocationType::GPU,  | 
 | 231 | +  //       common::errors::Unavailable("softmax_with_cross_entropy operator's "  | 
 | 232 | +  //                                   "CUDA kernel only runs on GPU device."));  | 
 | 233 | + | 
 | 234 | +  using LogitT = phi::bfloat16;  | 
 | 235 | +  const T* loss_grad_data = loss_grad.data<T>();  | 
 | 236 | +  DenseTensor* logit_grad = logits_grad;  | 
 | 237 | + | 
 | 238 | +  LogitT* logit_grad_data = nullptr;  | 
 | 239 | +  logit_grad_data = dev_ctx.template Alloc<LogitT>(logit_grad);  | 
 | 240 | + | 
 | 241 | +  const int rank = logit_grad->dims().size();  | 
 | 242 | +  const int axis_v = phi::funcs::CanonicalAxis(axis, rank);  | 
 | 243 | +  int axis_dim = logit_grad->dims()[axis_v];  | 
 | 244 | + | 
 | 245 | +  const int64_t n = phi::funcs::SizeToAxis(axis_v, logit_grad->dims());  | 
 | 246 | +  const int64_t d = phi::funcs::SizeFromAxis(axis_v, logit_grad->dims());  | 
 | 247 | +  const int64_t remain = d / axis_dim;  | 
 | 248 | + | 
 | 249 | +  const T* softmax_data = softmax.data<T>();  | 
 | 250 | +  const auto* label_data = label.data<LabelT>();  | 
 | 251 | + | 
 | 252 | +  // Launch optimized kernel with automatic selection  | 
 | 253 | +  LaunchOptimizedCrossEntropyGradKernel<T, LabelT, LogitT>(dev_ctx,  | 
 | 254 | +                                                           logit_grad_data,  | 
 | 255 | +                                                           loss_grad_data,  | 
 | 256 | +                                                           softmax_data,  | 
 | 257 | +                                                           label_data,  | 
 | 258 | +                                                           n,  | 
 | 259 | +                                                           axis_dim,  | 
 | 260 | +                                                           remain,  | 
 | 261 | +                                                           -100);  | 
 | 262 | +}  | 
 | 263 | + | 
 | 264 | +template <typename T, typename Context>  | 
 | 265 | +void CrossEntropyWithSoftmaxBwdWithDowncastKernel(const Context& dev_ctx,  | 
 | 266 | +                                                  const DenseTensor& label,  | 
 | 267 | +                                                  const DenseTensor& softmax,  | 
 | 268 | +                                                  const DenseTensor& loss_grad,  | 
 | 269 | +                                                  DenseTensor* logits_grad) {  | 
 | 270 | +  constexpr int axis = -1;  | 
 | 271 | +  if (logits_grad->numel() == 0) {  | 
 | 272 | +    dev_ctx.template Alloc<phi::bfloat16>(logits_grad);  | 
 | 273 | +    return;  | 
 | 274 | +  }  | 
 | 275 | +  auto dtype = label.dtype();  | 
 | 276 | +  PD_VISIT_INTEGRAL_TYPES(  | 
 | 277 | +      dtype, "CrossEntropyWithSoftmaxBwdWithDowncastGPUKernel", ([&] {  | 
 | 278 | +        CrossEntropyWithSoftmaxBwdWithDowncastGPUKernel<T, data_t>(  | 
 | 279 | +            dev_ctx, label, softmax, loss_grad, axis, logits_grad);  | 
 | 280 | +      }));  | 
 | 281 | +}  | 
 | 282 | + | 
 | 283 | +}  // namespace phi  | 
 | 284 | + | 
 | 285 | +PD_REGISTER_PLUGIN_KERNEL(cross_entropy_with_softmax_bwd_w_downcast,  | 
 | 286 | +                          metax_gpu,  | 
 | 287 | +                          ALL_LAYOUT,  | 
 | 288 | +                          phi::CrossEntropyWithSoftmaxBwdWithDowncastKernel,  | 
 | 289 | +                          float,  | 
 | 290 | +                          double,  | 
 | 291 | +                          phi::float16) {}  | 
0 commit comments