2424
2525namespace phi {
2626
27- template <typename T>
28- __global__ void CrossGrad (const T* x,
29- const T* y,
30- const T* out,
31- T* out_dx,
32- T* out_dy,
33- const int64_t stride,
34- const int64_t N,
35- phi::funcs::IndexCalculator<int > index_calculator) {
36- CUDA_KERNEL_LOOP (i, N) {
37- int64_t offset = index_calculator (i);
38-
39- auto pos0 = offset + 0 * stride;
40- auto pos1 = offset + 1 * stride;
41- auto pos2 = offset + 2 * stride;
27+ template <typename T, typename IndexType>
28+ __global__ void CrossGrad (
29+ const T* x,
30+ const T* y,
31+ const T* out,
32+ T* out_dx,
33+ T* out_dy,
34+ const IndexType stride,
35+ const IndexType N,
36+ phi::funcs::IndexCalculator<IndexType> index_calculator) {
37+ CUDA_KERNEL_LOOP_TYPE (i, N, IndexType) {
38+ IndexType offset = index_calculator (i);
39+
40+ IndexType pos0 = offset + 0 * stride;
41+ IndexType pos1 = offset + 1 * stride;
42+ IndexType pos2 = offset + 2 * stride;
4243
4344 using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
4445
@@ -168,11 +169,10 @@ void CrossGradKernel(const Context& dev_ctx,
168169 const auto * input_out_grad_data = input_out_grad.data <T>();
169170 auto * output_x_grad_data = dev_ctx.template Alloc <T>(x_grad);
170171 auto * output_y_grad_data = dev_ctx.template Alloc <T>(y_grad);
171- auto index_calculator = phi::funcs::IndexCalculator<int >(
172- merged_dims.size () - 1 , cal_dims, left_strides, full_strides);
173172
174173 backends::gpu::GpuLaunchConfig config =
175174 backends::gpu::GetGpuLaunchConfig1D (dev_ctx, numel / 3 );
175+ constexpr int64_t int_max = std::numeric_limits<int >::max ();
176176 if (IsComplexType (x.dtype ())) {
177177 DenseTensor x_conj, y_conj;
178178 DenseTensorMeta meta_xy (x.dtype (), x.dims ());
@@ -189,30 +189,67 @@ void CrossGradKernel(const Context& dev_ctx,
189189 input_y_data, numel, input_y_conj_data);
190190 for_range (functor_x);
191191 for_range (functor_y);
192-
193- CrossGrad<<<config.block_per_grid,
194- config.thread_per_block,
195- 0 ,
196- dev_ctx.stream()>>> (input_x_conj_data,
197- input_y_conj_data,
198- input_out_grad_data,
199- output_x_grad_data,
200- output_y_grad_data,
201- full_strides[merge_axis],
202- numel / 3 ,
203- index_calculator);
192+ if (full_strides[merge_axis] * 2 > int_max || numel / 3 > int_max) {
193+ auto index_calculator = phi::funcs::IndexCalculator<int64_t >(
194+ merged_dims.size () - 1 , cal_dims, left_strides, full_strides);
195+ CrossGrad<<<config.block_per_grid,
196+ config.thread_per_block,
197+ 0 ,
198+ dev_ctx.stream()>>> (input_x_conj_data,
199+ input_y_conj_data,
200+ input_out_grad_data,
201+ output_x_grad_data,
202+ output_y_grad_data,
203+ full_strides[merge_axis],
204+ numel / 3 ,
205+ index_calculator);
206+ } else {
207+ auto index_calculator = phi::funcs::IndexCalculator<int32_t >(
208+ merged_dims.size () - 1 , cal_dims, left_strides, full_strides);
209+ CrossGrad<<<config.block_per_grid,
210+ config.thread_per_block,
211+ 0 ,
212+ dev_ctx.stream()>>> (
213+ input_x_conj_data,
214+ input_y_conj_data,
215+ input_out_grad_data,
216+ output_x_grad_data,
217+ output_y_grad_data,
218+ static_cast <int32_t >(full_strides[merge_axis]),
219+ static_cast <int32_t >(numel / 3 ),
220+ index_calculator);
221+ }
204222 } else {
205- CrossGrad<<<config.block_per_grid,
206- config.thread_per_block,
207- 0 ,
208- dev_ctx.stream()>>> (input_x_data,
209- input_y_data,
210- input_out_grad_data,
211- output_x_grad_data,
212- output_y_grad_data,
213- full_strides[merge_axis],
214- numel / 3 ,
215- index_calculator);
223+ if (full_strides[merge_axis] * 2 > int_max || numel / 3 > int_max) {
224+ auto index_calculator = phi::funcs::IndexCalculator<int64_t >(
225+ merged_dims.size () - 1 , cal_dims, left_strides, full_strides);
226+ CrossGrad<<<config.block_per_grid,
227+ config.thread_per_block,
228+ 0 ,
229+ dev_ctx.stream()>>> (input_x_data,
230+ input_y_data,
231+ input_out_grad_data,
232+ output_x_grad_data,
233+ output_y_grad_data,
234+ full_strides[merge_axis],
235+ numel / 3 ,
236+ index_calculator);
237+ } else {
238+ auto index_calculator = phi::funcs::IndexCalculator<int32_t >(
239+ merged_dims.size () - 1 , cal_dims, left_strides, full_strides);
240+ CrossGrad<<<config.block_per_grid,
241+ config.thread_per_block,
242+ 0 ,
243+ dev_ctx.stream()>>> (
244+ input_x_data,
245+ input_y_data,
246+ input_out_grad_data,
247+ output_x_grad_data,
248+ output_y_grad_data,
249+ static_cast <int32_t >(full_strides[merge_axis]),
250+ static_cast <int32_t >(numel / 3 ),
251+ index_calculator);
252+ }
216253 }
217254}
218255} // namespace phi
0 commit comments