
Commit 6b02bab

[doc] Add docstring for segment reduce. (dmlc#2375)
1 parent 35a3ead commit 6b02bab

8 files changed: +81 / -27 lines

docs/source/api/python/dgl.ops.rst  (+11 / -1)

@@ -239,7 +239,7 @@ Like GSpMM, GSDDMM operators support both homogeneous and bipartite graph.
 Edge Softmax module
 -------------------
 
-We also provide framework agnostic edge softmax module which was frequently used in
+DGL also provides a framework-agnostic edge softmax module which is frequently used in
 GNN-like structures, e.g.
 `Graph Attention Network <https://arxiv.org/pdf/1710.10903.pdf>`_,
 `Transformer <https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf>`_,
@@ -250,6 +250,16 @@ GNN-like structures, e.g.
 
     edge_softmax
 
+Segment Reduce Module
+---------------------
+
+DGL provides operators to reduce a value tensor along the first dimension by segments.
+
+.. autosummary::
+    :toctree: ../../generated/
+
+    segment_reduce
+
 Relation with Message Passing APIs
 ----------------------------------

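For readers of the new docs section, a minimal usage sketch of the documented operator may help. It assumes the PyTorch backend and the public ``dgl.ops.segment_reduce(seglen, value, reducer='sum')`` signature, which takes per-segment lengths (the backend primitive documented below takes offsets instead); the tensor values are illustrative only and are not part of the commit.

    import torch
    import dgl

    x = torch.arange(6, dtype=torch.float32).view(6, 1)   # values to aggregate
    seglen = torch.tensor([2, 3, 0, 1])                    # segment lengths; zero-length segments are allowed
    y = dgl.ops.segment_reduce(seglen, x, reducer='sum')
    # expected under these assumptions: tensor([[1.], [9.], [0.], [5.]]), shape (len(seglen), 1)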
python/dgl/backend/backend.py  (+12 / -8)

@@ -1512,23 +1512,27 @@ def segment_reduce(op, x, offsets):
     """Segment reduction operator.
 
     It aggregates the value tensor along the first dimension by segments.
-    The first argument ``seglen`` stores the length of each segment. Its
-    summation must be equal to the first dimension of the ``value`` tensor.
-    Zero-length segments are allowed.
+    The argument ``offsets`` specifies the start offset of each segment (and
+    the upper bound of the last segment). Zero-length segments are allowed.
+
+    .. math::
+      y_i = \Phi_{j=\mathrm{offsets}_i}^{\mathrm{offsets}_{i+1}-1} x_j
+
+    where :math:`\Phi` is the reduce operator.
 
     Parameters
     ----------
     op : str
-        Aggregation method. Can be 'sum', 'max', 'min'.
-    seglen : Tensor
-        Segment lengths.
-    value : Tensor
+        Aggregation method. Can be ``sum``, ``max``, ``min``.
+    x : Tensor
         Value to aggregate.
+    offsets : Tensor
+        The start offsets of segments.
 
     Returns
     -------
     Tensor
-        Aggregated tensor of shape ``(len(seglen), value.shape[1:])``.
+        Aggregated tensor of shape ``(len(offsets) - 1, x.shape[1:])``.
     """
     pass

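Since the new docstring is where the offsets-based semantics are spelled out, a small NumPy reference of the formula above may be useful. This is an illustrative sketch of the documented behaviour, not DGL's implementation, and the helper name is invented for the example.

    import numpy as np

    def segment_reduce_ref(op, x, offsets):
        """Reference semantics: y[i] = op(x[offsets[i]:offsets[i+1]], axis=0)."""
        reduce_fn = {'sum': np.sum, 'max': np.max, 'min': np.min}[op]
        out = []
        for i in range(len(offsets) - 1):
            seg = x[offsets[i]:offsets[i + 1]]
            if len(seg) == 0:
                # zero-length segments are allowed; they are filled with zeros here
                out.append(np.zeros(x.shape[1:], dtype=x.dtype))
            else:
                out.append(reduce_fn(seg, axis=0))
        return np.stack(out)

    x = np.arange(6, dtype=np.float32).reshape(6, 1)
    offsets = np.array([0, 2, 5, 5, 6])            # four segments; the third is empty
    print(segment_reduce_ref('sum', x, offsets))   # [[1.] [9.] [0.] [5.]]

The output has ``len(offsets) - 1`` rows, matching the Returns section above.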
python/dgl/ops/segment.py  (-2)

@@ -69,8 +69,6 @@ def segment_softmax(seglen, value):
         Segment lengths.
     value : Tensor
         Value to aggregate.
-    reducer : str, optional
-        Aggregation method. Can be 'sum', 'max', 'min', 'mean'.
 
     Returns
     -------

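For context on ``segment_softmax``, whose docstring is trimmed here (the ``reducer`` argument it documented does not exist on this function), a hedged NumPy sketch of per-segment softmax semantics; this is illustrative only, not the DGL implementation.

    import numpy as np

    def segment_softmax_ref(seglen, value):
        """Softmax applied independently within each segment along dim 0."""
        out = np.zeros_like(value)
        start = 0
        for n in seglen:
            if n > 0:
                seg = value[start:start + n]
                e = np.exp(seg - seg.max(axis=0, keepdims=True))  # numerically stable
                out[start:start + n] = e / e.sum(axis=0, keepdims=True)
            start += n
        return out

    v = np.array([[1.0], [2.0], [0.5], [0.5], [0.5]])
    print(segment_softmax_ref([2, 3], v))  # rows 0-1 sum to 1, rows 2-4 sum to 1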
python/dgl/sparse.py  (+11 / -7)

@@ -252,18 +252,22 @@ def _segment_reduce(op, feat, offsets):
     r"""Segment reduction operator.
 
     It aggregates the value tensor along the first dimension by segments.
-    The first argument ``seglen`` stores the length of each segment. Its
-    summation must be equal to the first dimension of the ``value`` tensor.
-    Zero-length segments are allowed.
+    The argument ``offsets`` specifies the start offset of each segment (and
+    the upper bound of the last segment). Zero-length segments are allowed.
+
+    .. math::
+      y_i = \Phi_{j=\mathrm{offsets}_i}^{\mathrm{offsets}_{i+1}-1} x_j
+
+    where :math:`\Phi` is the reduce operator.
 
     Parameters
     ----------
     op : str
-        Aggregation method. Can be 'sum', 'max', 'min'.
-    seglen : Tensor
-        Segment lengths.
-    value : Tensor
+        Aggregation method. Can be ``sum``, ``max``, ``min``.
+    feat : Tensor
         Value to aggregate.
+    offsets : Tensor
+        The start offsets of segments.
 
     Returns
     -------

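Note that this internal primitive takes ``offsets`` while the public wrapper in ``python/dgl/ops/segment.py`` takes segment lengths; the two are related by a prefix sum. A small sketch of that conversion (illustrative; the exact code DGL uses to build the offsets is not shown in this commit):

    import numpy as np

    seglen = np.array([2, 3, 0, 1])
    # offsets are the cumulative sum of the segment lengths, prepended with 0;
    # the last entry is the upper bound of the final segment.
    offsets = np.concatenate([[0], np.cumsum(seglen)])
    print(offsets)  # [0 2 5 5 6]; segment i is x[offsets[i]:offsets[i+1]]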
src/array/cpu/segment_reduce.h  (+20)

@@ -12,6 +12,12 @@ namespace dgl {
 namespace aten {
 namespace cpu {
 
+/*!
+ * \brief CPU kernel of segment sum.
+ * \param feat The input tensor.
+ * \param offsets The offset tensor storing the ranges of segments.
+ * \param out The output tensor.
+ */
 template <typename IdType, typename DType>
 void SegmentSum(NDArray feat, NDArray offsets, NDArray out) {
   int n = out->shape[0];
@@ -31,6 +37,14 @@ void SegmentSum(NDArray feat, NDArray offsets, NDArray out) {
   }
 }
 
+/*!
+ * \brief CPU kernel of segment min/max.
+ * \param feat The input tensor.
+ * \param offsets The offset tensor storing the ranges of segments.
+ * \param out The output tensor.
+ * \param arg An auxiliary tensor storing the argmin/max information
+ *        used in backward phase.
+ */
 template <typename IdType, typename DType, typename Cmp>
 void SegmentCmp(NDArray feat, NDArray offsets,
                 NDArray out, NDArray arg) {
@@ -58,6 +72,12 @@ void SegmentCmp(NDArray feat, NDArray offsets,
   }
 }
 
+/*!
+ * \brief CPU kernel of backward phase of segment min/max.
+ * \param feat The input tensor.
+ * \param arg The argmin/argmax tensor.
+ * \param out The output tensor.
+ */
 template <typename IdType, typename DType>
 void BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out) {
   int n = feat->shape[0];

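The ``arg`` tensor mentioned in these comments records, for each output position, the index of the input row that produced the min/max; the backward kernel then routes the incoming gradient to exactly that row. A hedged NumPy sketch of that gradient scatter (purely illustrative of the idea; shapes, naming, and the handling of empty segments are assumptions, not the kernel code):

    import numpy as np

    def backward_segment_cmp_ref(grad_out, arg, num_inputs):
        """Scatter each output-row gradient back to its argmin/argmax input row."""
        grad_in = np.zeros((num_inputs, grad_out.shape[1]), dtype=grad_out.dtype)
        for i in range(grad_out.shape[0]):       # one output row per segment
            for k in range(grad_out.shape[1]):   # per feature position
                grad_in[arg[i, k], k] = grad_out[i, k]
        return grad_in

    grad_out = np.array([[1.0], [2.0]])
    arg = np.array([[0], [4]])                   # row 0 won segment 0, row 4 won segment 1
    print(backward_segment_cmp_ref(grad_out, arg, num_inputs=5))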
src/array/cuda/sddmm.cuh  (+2 / -2)

@@ -146,7 +146,7 @@ __device__ __forceinline__ Idx BinarySearchSrc(const Idx *array, Idx length, Idx
  * is responsible for the computation on different edges. Threadblocks
  * on the x-axis are responsible for the computation on different positions
  * in feature dimension.
- * To efficiently find the source node idx and destination node index of an 
+ * To efficiently find the source node idx and destination node index of an
  * given edge on Csr format, it uses binary search (time complexity O(log N)).
  */
 template <typename Idx, typename DType, typename BinaryOp,
@@ -239,7 +239,7 @@ void SDDMMCoo(
         coo.num_rows, coo.num_cols, nnz, reduce_dim,
         lhs_off, rhs_off,
         lhs_len, rhs_len, len);
-      }); 
+      });
   } else {
     const int ntx = FindNumThreads(len);
     const int nty = CUDA_MAX_NUM_THREADS / ntx;
const int nty = CUDA_MAX_NUM_THREADS / ntx;

src/array/cuda/segment_reduce.cuh

+19-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ namespace cuda {
1919

2020
/*!
2121
* \brief CUDA kernel of segment reduce.
22+
* \note each blockthread is responsible for aggregation on a row
23+
* in the result tensor.
2224
*/
2325
template <typename IdType, typename DType,
2426
typename ReduceOp>
@@ -41,7 +43,9 @@ __global__ void SegmentReduceKernel(
4143
}
4244

4345
/*!
44-
* \brief CUDA kernel of segment reduce.
46+
* \brief CUDA kernel of backward phase in segment min/max.
47+
* \note each blockthread is responsible for writing a row in the
48+
* result gradient tensor by lookup the ArgMin/Max for index information.
4549
*/
4650
template <typename IdType, typename DType>
4751
__global__ void BackwardSegmentCmpKernel(
@@ -57,6 +61,13 @@ __global__ void BackwardSegmentCmpKernel(
5761
}
5862
}
5963

64+
/*!
65+
* \brief CUDA implementation of forward phase of Segment Reduce.
66+
* \param feat The input tensor.
67+
* \param offsets The offsets tensor.
68+
* \param out The output tensor.
69+
* \param arg An auxiliary tensor storing ArgMax/Min information,
70+
*/
6071
template <typename IdType, typename DType, typename ReduceOp>
6172
void SegmentReduce(
6273
NDArray feat,
@@ -80,12 +91,19 @@ void SegmentReduce(
8091
const int nty = 1;
8192
const dim3 nblks(nbx, nby);
8293
const dim3 nthrs(ntx, nty);
94+
// TODO(zihao): try cub's DeviceSegmentedReduce and compare the performance.
8395
CUDA_KERNEL_CALL((SegmentReduceKernel<IdType, DType, ReduceOp>),
8496
nblks, nthrs, 0, thr_entry->stream,
8597
feat_data, offsets_data, out_data, arg_data,
8698
n, dim);
8799
}
88100

101+
/*!
102+
* \brief CUDA implementation of backward phase of Segment Reduce with Min/Max reducer.
103+
* \param feat The input tensor.
104+
* \param arg The ArgMin/Max information, used for indexing.
105+
* \param out The output tensor.
106+
*/
89107
template <typename IdType, typename DType>
90108
void BackwardSegmentCmp(
91109
NDArray feat,

src/array/cuda/spmm.cuh  (+6 / -6)

@@ -19,7 +19,7 @@ using namespace cuda;
 namespace aten {
 namespace cuda {
 
-/*! 
+/*!
  * \brief CUDA Kernel of filling the vector started from ptr of size length
  *        with val.
  * \note internal use only.
@@ -134,7 +134,7 @@ __global__ void ArgSpMMCooKernel(
 /*!
  * \brief CUDA kernel of g-SpMM on Coo format.
  * \note it uses node parallel strategy, different threadblocks (on y-axis)
- *       is responsible for the computation on different destination nodes. 
+ *       is responsible for the computation on different destination nodes.
  *       Threadblocks on the x-axis are responsible for the computation on
  *       different positions in feature dimension.
  */
@@ -191,10 +191,10 @@ __global__ void SpMMCsrKernel(
  * \param ufeat The feature on source nodes.
  * \param efeat The feature on edges.
  * \param out The result feature on destination nodes.
- * \param argu Arg-Min/Max on source nodes, which refers the source node indices 
+ * \param argu Arg-Min/Max on source nodes, which refers the source node indices
  *        correspond to the minimum/maximum values of reduction result on
  *        destination nodes. It's useful in computing gradients of Min/Max reducer.
- * \param arge Arg-Min/Max on edges. which refers the source node indices 
+ * \param arge Arg-Min/Max on edges. which refers the source node indices
  *        correspond to the minimum/maximum values of reduction result on
  *        destination nodes. It's useful in computing gradients of Min/Max reducer.
  */
@@ -263,10 +263,10 @@ void SpMMCoo(
 * \param ufeat The feature on source nodes.
 * \param efeat The feature on edges.
 * \param out The result feature on destination nodes.
- * \param argu Arg-Min/Max on source nodes, which refers the source node indices 
+ * \param argu Arg-Min/Max on source nodes, which refers the source node indices
 *        correspond to the minimum/maximum values of reduction result on
 *        destination nodes. It's useful in computing gradients of Min/Max reducer.
- * \param arge Arg-Min/Max on edges. which refers the source node indices 
+ * \param arge Arg-Min/Max on edges. which refers the source node indices
 *        correspond to the minimum/maximum values of reduction result on
 *        destination nodes. It's useful in computing gradients of Min/Max reducer.
 */
