|
23 | 23 |
|
24 | 24 | namespace phi { |
25 | 25 |
|
26 | | -template <typename T> |
| 26 | +template <typename T, typename IndexT> |
27 | 27 | static __forceinline__ __device__ T Unnormalize(T coord, |
28 | | - int size, |
| 28 | + IndexT size, |
29 | 29 | bool align_corners) { |
30 | 30 | return align_corners ? ((coord + 1.f) / 2) * (size - 1) |
31 | 31 | : ((coord + 1.f) * size - 1) / 2; |
32 | 32 | } |
33 | 33 |
|
34 | | -template <typename T> |
35 | | -static __forceinline__ __device__ T ClipIndexes(T in, int max_value) { |
| 34 | +template <typename T, typename IndexT> |
| 35 | +static __forceinline__ __device__ T ClipIndexes(T in, IndexT max_value) { |
36 | 36 | return min(static_cast<T>(max_value - 1), max(in, static_cast<T>(0))); |
37 | 37 | } |
38 | 38 |
|
39 | | -template <typename T> |
| 39 | +template <typename T, typename IndexT> |
40 | 40 | static __forceinline__ __device__ T ReflectIndexes(T in, |
41 | | - int twice_low, |
42 | | - int twice_high) { |
| 41 | + IndexT twice_low, |
| 42 | + IndexT twice_high) { |
43 | 43 | if (twice_low == twice_high) { |
44 | 44 | return static_cast<T>(0); |
45 | 45 | } |
46 | 46 | T min = static_cast<T>(twice_low) / 2; |
47 | 47 | T span = static_cast<T>(twice_high - twice_low) / 2; |
48 | 48 | in = fabs(in - min); |
49 | 49 | T extra = fmod(in, span); |
50 | | - int flips = static_cast<int>(floor(in / span)); |
| 50 | + IndexT flips = floor(in / span); |
51 | 51 | return (flips & 1) ? span - extra + min : extra + min; // cond ? odd : even |
52 | 52 | } |
53 | 53 |
|
54 | | -template <typename T> |
| 54 | +template <typename T, typename IndexT> |
55 | 55 | static __forceinline__ __device__ T ComputePositions(T coord, |
56 | | - int size, |
| 56 | + IndexT size, |
57 | 57 | PaddingMode padding_mode, |
58 | 58 | bool align_corners) { |
59 | | - coord = Unnormalize<T>(coord, size, align_corners); |
| 59 | + coord = Unnormalize(coord, size, align_corners); |
60 | 60 | if (padding_mode == PaddingMode::border) { |
61 | 61 | coord = ClipIndexes(coord, size); |
62 | 62 | } else if (padding_mode == PaddingMode::reflect) { |
63 | | - coord = align_corners ? ReflectIndexes(coord, 0, 2 * (size - 1)) |
64 | | - : ReflectIndexes(coord, -1, 2 * size - 1); |
| 63 | + coord = align_corners ? ReflectIndexes<T, IndexT>(coord, 0, 2 * (size - 1)) |
| 64 | + : ReflectIndexes<T, IndexT>(coord, -1, 2 * size - 1); |
65 | 65 | coord = ClipIndexes(coord, size); |
66 | 66 | } |
67 | 67 | return SafeDownGradeToIntRange(coord); |
68 | 68 | } |
69 | 69 |
|
70 | | -template <typename T> |
71 | | -__global__ void GridSampleCudaKernel(const int nthreads, |
72 | | - int n, |
73 | | - int out_c, |
74 | | - int out_h, |
75 | | - int out_w, |
76 | | - int in_h, |
77 | | - int in_w, |
78 | | - const T* input, |
79 | | - const T* grid, |
80 | | - T* output, |
| 70 | +template <typename T, typename IndexT> |
| 71 | +__global__ void GridSampleCudaKernel(IndexT n, |
| 72 | + IndexT out_c, |
| 73 | + IndexT out_hw, |
| 74 | + IndexT in_h, |
| 75 | + IndexT in_w, |
| 76 | + const T* __restrict__ input, |
| 77 | + const T* __restrict__ grid, |
| 78 | + T* __restrict__ output, |
81 | 79 | const Mode mode, |
82 | 80 | const PaddingMode padding_mode, |
83 | 81 | bool align_corners) { |
84 | | - int inp_sN = out_c * in_h * in_w; |
85 | | - |
86 | | - int inp_sC = in_h * in_w; |
87 | | - int inp_sH = in_w; |
88 | | - int inp_sW = 1; |
89 | | - int grid_sN = out_h * out_w * 2; |
90 | | - int grid_sH = out_w * 2; |
91 | | - int grid_sW = 2; |
92 | | - int grid_sCoor = 1; |
93 | | - int out_sN = out_c * out_h * out_w; |
94 | | - int out_sC = out_h * out_w; |
95 | | - int out_sH = out_w; |
96 | | - int out_sW = 1; |
97 | | - CUDA_KERNEL_LOOP(index, nthreads) { |
98 | | - const int w = index % out_w; |
99 | | - const int h = (index / out_w) % out_h; |
100 | | - const int n = index / (out_h * out_w); |
101 | | - const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; |
| 82 | + IndexT nthreads = n * out_hw; |
| 83 | + IndexT inp_sN = out_c * (in_h * in_w); |
| 84 | + IndexT inp_sC = in_h * in_w; |
| 85 | + IndexT inp_sH = in_w; |
| 86 | + IndexT inp_sW = 1; |
| 87 | + IndexT grid_sNHW = 2; |
| 88 | + IndexT grid_sCoor = 1; |
| 89 | + IndexT out_sN = out_c * out_hw; |
| 90 | + IndexT out_sC = out_hw; |
| 91 | + IndexT out_sHW = 1; |
| 92 | + CUDA_KERNEL_LOOP_TYPE(index, nthreads, IndexT) { |
| 93 | + const IndexT hw = index % out_hw; |
| 94 | + const IndexT n = index / out_hw; |
| 95 | + const IndexT grid_offset = index * grid_sNHW; |
102 | 96 |
|
103 | 97 | T ix = grid[grid_offset]; |
104 | 98 | T iy = grid[grid_offset + grid_sCoor]; |
105 | 99 |
|
106 | 100 | ix = ComputePositions(ix, in_w, padding_mode, align_corners); |
107 | 101 | iy = ComputePositions(iy, in_h, padding_mode, align_corners); |
108 | 102 | if (mode == Mode::bilinear) { |
109 | | - int ix_nw = static_cast<int>(floor(ix)); |
110 | | - int iy_nw = static_cast<int>(floor(iy)); |
111 | | - int ix_ne = ix_nw + 1; |
112 | | - int iy_ne = iy_nw; |
113 | | - int ix_sw = ix_nw; |
114 | | - int iy_sw = iy_nw + 1; |
115 | | - int ix_se = ix_nw + 1; |
116 | | - int iy_se = iy_nw + 1; |
| 103 | + IndexT ix_nw = floor(ix); |
| 104 | + IndexT iy_nw = floor(iy); |
| 105 | + IndexT ix_ne = ix_nw + 1; |
| 106 | + IndexT iy_ne = iy_nw; |
| 107 | + IndexT ix_sw = ix_nw; |
| 108 | + IndexT iy_sw = iy_nw + 1; |
| 109 | + IndexT ix_se = ix_nw + 1; |
| 110 | + IndexT iy_se = iy_nw + 1; |
117 | 111 |
|
118 | 112 | T nw = (ix_se - ix) * (iy_se - iy); |
119 | 113 | T ne = (ix - ix_sw) * (iy_sw - iy); |
120 | 114 | T sw = (ix_ne - ix) * (iy - iy_ne); |
121 | 115 | T se = (ix - ix_nw) * (iy - iy_nw); |
122 | 116 |
|
123 | | - auto inp_offset_NC = n * inp_sN; |
| 117 | + IndexT inp_offset_NC = n * inp_sN; |
| 118 | + T* out_ptr_NCHW = output + (n * out_sN + hw * out_sHW); |
124 | 119 |
|
125 | | - auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW; |
126 | | - for (int c = 0; c < out_c; |
| 120 | + for (IndexT c = 0; c < out_c; |
127 | 121 | ++c, inp_offset_NC += inp_sC, out_ptr_NCHW += out_sC) { |
128 | | - *out_ptr_NCHW = static_cast<T>(0); |
| 122 | + T value{0}; |
129 | 123 | if (InBounds(iy_nw, ix_nw, in_h, in_w)) { |
130 | | - *out_ptr_NCHW += |
131 | | - input[inp_offset_NC + iy_nw * inp_sH + ix_nw * inp_sW] * nw; |
| 124 | + value += input[inp_offset_NC + iy_nw * inp_sH + ix_nw * inp_sW] * nw; |
132 | 125 | } |
133 | 126 | if (InBounds(iy_ne, ix_ne, in_h, in_w)) { |
134 | | - *out_ptr_NCHW += |
135 | | - input[inp_offset_NC + iy_ne * inp_sH + ix_ne * inp_sW] * ne; |
| 127 | + value += input[inp_offset_NC + iy_ne * inp_sH + ix_ne * inp_sW] * ne; |
136 | 128 | } |
137 | 129 | if (InBounds(iy_sw, ix_sw, in_h, in_w)) { |
138 | | - *out_ptr_NCHW += |
139 | | - input[inp_offset_NC + iy_sw * inp_sH + ix_sw * inp_sW] * sw; |
| 130 | + value += input[inp_offset_NC + iy_sw * inp_sH + ix_sw * inp_sW] * sw; |
140 | 131 | } |
141 | 132 | if (InBounds(iy_se, ix_se, in_h, in_w)) { |
142 | | - *out_ptr_NCHW += |
143 | | - input[inp_offset_NC + iy_se * inp_sH + ix_se * inp_sW] * se; |
| 133 | + value += input[inp_offset_NC + iy_se * inp_sH + ix_se * inp_sW] * se; |
144 | 134 | } |
| 135 | + *out_ptr_NCHW = value; |
145 | 136 | } |
146 | 137 | } else if (mode == Mode::nearest) { |
147 | | - int ix_nearest = static_cast<int>(std::nearbyint(ix)); |
148 | | - int iy_nearest = static_cast<int>(std::nearbyint(iy)); |
149 | | - auto inp_offset_NC = n * inp_sN; |
150 | | - auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW; |
151 | | - for (int c = 0; c < out_c; |
| 138 | + IndexT ix_nearest = std::nearbyint(ix); |
| 139 | + IndexT iy_nearest = std::nearbyint(iy); |
| 140 | + IndexT inp_offset_NC = n * inp_sN; |
| 141 | + T* out_ptr_NCHW = output + (n * out_sN + hw * out_sHW); |
| 142 | + for (IndexT c = 0; c < out_c; |
152 | 143 | ++c, inp_offset_NC += inp_sC, out_ptr_NCHW += out_sC) { |
153 | 144 | if (InBounds(iy_nearest, ix_nearest, in_h, in_w)) { |
154 | 145 | *out_ptr_NCHW = |
@@ -349,38 +340,54 @@ void GridSampleKernel(const Context& dev_ctx, |
349 | 340 | } |
350 | 341 |
|
351 | 342 | if (x.dims().size() == 4) { |
352 | | - const int n = grid.dims()[0]; |
353 | | - const int out_h = grid.dims()[1]; |
354 | | - const int out_w = grid.dims()[2]; |
355 | | - const int c = x.dims()[1]; |
356 | | - const int in_h = x.dims()[2]; |
357 | | - const int in_w = x.dims()[3]; |
| 343 | + const int64_t n = grid.dims()[0]; |
| 344 | + const int64_t out_h = grid.dims()[1]; |
| 345 | + const int64_t out_w = grid.dims()[2]; |
| 346 | + const int64_t c = x.dims()[1]; |
| 347 | + const int64_t in_h = x.dims()[2]; |
| 348 | + const int64_t in_w = x.dims()[3]; |
358 | 349 | VLOG(3) << "n: " << n << "; c: " << c << "; out_h: " << out_h |
359 | 350 | << "; out_w: " << out_w; |
360 | 351 |
|
361 | 352 | auto* output_data = dev_ctx.template Alloc<T>(out); |
362 | 353 | VLOG(3) << "out dims: " << out->dims()[0] << "; " << out->dims()[1] << "; " |
363 | 354 | << out->dims()[2] << "; " << out->dims()[3]; |
364 | 355 |
|
365 | | - int count = static_cast<int>(n * out_h * out_w); |
| 356 | + int64_t count = n * out_h * out_w; |
366 | 357 | auto cu_stream = dev_ctx.stream(); |
367 | 358 | backends::gpu::GpuLaunchConfig config = |
368 | 359 | backends::gpu::GetGpuLaunchConfig1D(dev_ctx, count); |
369 | | - GridSampleCudaKernel<T> |
370 | | - <<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>( |
371 | | - count, |
372 | | - n, |
373 | | - c, |
374 | | - out_h, |
375 | | - out_w, |
376 | | - in_h, |
377 | | - in_w, |
378 | | - x.data<T>(), |
379 | | - grid.data<T>(), |
380 | | - output_data, |
381 | | - enum_mode, |
382 | | - enum_padding_mode, |
383 | | - align_corners); |
| 360 | + if (x.numel() <= std::numeric_limits<int>::max() && |
| 361 | + grid.numel() <= std::numeric_limits<int>::max() && |
| 362 | + out->numel() <= std::numeric_limits<int>::max()) { |
| 363 | + GridSampleCudaKernel<T, int> |
| 364 | + <<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>( |
| 365 | + n, |
| 366 | + c, |
| 367 | + out_h * out_w, |
| 368 | + in_h, |
| 369 | + in_w, |
| 370 | + x.data<T>(), |
| 371 | + grid.data<T>(), |
| 372 | + output_data, |
| 373 | + enum_mode, |
| 374 | + enum_padding_mode, |
| 375 | + align_corners); |
| 376 | + } else { |
| 377 | + GridSampleCudaKernel<T, int64_t> |
| 378 | + <<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>( |
| 379 | + n, |
| 380 | + c, |
| 381 | + out_h * out_w, |
| 382 | + in_h, |
| 383 | + in_w, |
| 384 | + x.data<T>(), |
| 385 | + grid.data<T>(), |
| 386 | + output_data, |
| 387 | + enum_mode, |
| 388 | + enum_padding_mode, |
| 389 | + align_corners); |
| 390 | + } |
384 | 391 | } else { |
385 | 392 | const int n = grid.dims()[0]; |
386 | 393 | const int out_d = grid.dims()[1]; |
|
0 commit comments