@@ -19,6 +19,8 @@ namespace cuda {
 
 /*!
  * \brief CUDA kernel of segment reduce.
+ * \note each blockthread is responsible for aggregation on a row
+ *   in the result tensor.
  */
 template <typename IdType, typename DType,
           typename ReduceOp>
@@ -41,7 +43,9 @@ __global__ void SegmentReduceKernel(
 }
 
 /*!
- * \brief CUDA kernel of segment reduce.
+ * \brief CUDA kernel of backward phase in segment min/max.
+ * \note each blockthread is responsible for writing a row in the
+ *   result gradient tensor by looking up the ArgMin/Max for index information.
  */
 template <typename IdType, typename DType>
 __global__ void BackwardSegmentCmpKernel(
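Note on the \note added in the hunk above: the backward pass for a min/max reducer only routes the upstream gradient to the positions recorded by the forward ArgMin/Max. A minimal sketch of that access pattern follows (the real kernel body is elided from this hunk; the one-block-per-row mapping, the names grad_out, arg, grad_feat, and the negative-index sentinel are illustrative assumptions, not identifiers from this file):

#include <cstdint>

template <typename IdType, typename DType>
__global__ void BackwardSegmentCmpSketch(
    const DType* grad_out,   // gradient of the reduced output, shape (n, dim)
    const IdType* arg,       // ArgMin/Max indices recorded in the forward pass
    DType* grad_feat,        // gradient of the input features, assumed zero-initialized
    int64_t n, int64_t dim) {
  // Assumed mapping: one block per output row, threads sweep the columns.
  for (int row = blockIdx.x; row < n; row += gridDim.x) {
    for (int col = threadIdx.x; col < dim; col += blockDim.x) {
      const IdType idx = arg[row * dim + col];
      if (idx >= 0)  // assuming a negative index marks an empty segment
        grad_feat[idx * dim + col] = grad_out[row * dim + col];
    }
  }
}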
@@ -57,6 +61,13 @@ __global__ void BackwardSegmentCmpKernel(
   }
 }
 
+/*!
+ * \brief CUDA implementation of forward phase of Segment Reduce.
+ * \param feat The input tensor.
+ * \param offsets The offsets tensor.
+ * \param out The output tensor.
+ * \param arg An auxiliary tensor storing ArgMax/Min information.
+ */
 template <typename IdType, typename DType, typename ReduceOp>
 void SegmentReduce(
     NDArray feat,
@@ -80,12 +91,19 @@ void SegmentReduce(
   const int nty = 1;
   const dim3 nblks(nbx, nby);
   const dim3 nthrs(ntx, nty);
+  // TODO(zihao): try cub's DeviceSegmentedReduce and compare the performance.
   CUDA_KERNEL_CALL((SegmentReduceKernel<IdType, DType, ReduceOp>),
       nblks, nthrs, 0, thr_entry->stream,
       feat_data, offsets_data, out_data, arg_data,
       n, dim);
 }
 
+/*!
+ * \brief CUDA implementation of backward phase of Segment Reduce with Min/Max reducer.
+ * \param feat The input tensor.
+ * \param arg The ArgMin/Max information, used for indexing.
+ * \param out The output tensor.
+ */
 template <typename IdType, typename DType>
 void BackwardSegmentCmp(
     NDArray feat,
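Regarding the TODO(zihao) line added above: cub's segmented reduction accepts begin/end offset iterators, which line up with the usual length-(num_segments + 1) offsets convention by passing d_offsets and d_offsets + 1. A hedged 1D sum illustration of that call pattern; SegmentedSumWithCub, d_in, d_out, d_offsets, and num_segments are placeholder names, not part of this file:

#include <cuda_runtime.h>
#include <cub/cub.cuh>

void SegmentedSumWithCub(const float* d_in, float* d_out,
                         const int* d_offsets, int num_segments,
                         cudaStream_t stream) {
  void* d_temp_storage = nullptr;
  size_t temp_storage_bytes = 0;
  // First call only queries the required temporary-storage size.
  cub::DeviceSegmentedReduce::Sum(d_temp_storage, temp_storage_bytes,
      d_in, d_out, num_segments, d_offsets, d_offsets + 1, stream);
  cudaMalloc(&d_temp_storage, temp_storage_bytes);
  // Second call performs the segmented sum on the given stream.
  cub::DeviceSegmentedReduce::Sum(d_temp_storage, temp_storage_bytes,
      d_in, d_out, num_segments, d_offsets, d_offsets + 1, stream);
  cudaFree(d_temp_storage);
}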