Skip to content

Commit 3adbfa1

Browse files
authored
[Performance] Use segment operators for graph readout. (dmlc#2361)
* upd * upd * update * upd * upd * upd * fix * lint * lint * pylint * doc
1 parent 45e3e9a commit 3adbfa1

18 files changed

+750
-26
lines changed

python/dgl/backend/backend.py

+42
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,23 @@ def reduce_sum(input):
370370
"""
371371
pass
372372

373+
def cumsum(input, dim):
    """Compute the running (cumulative) sum of a tensor along one axis.

    This is the backend-agnostic interface stub; each framework backend
    supplies the actual implementation.

    Parameters
    ----------
    input : Tensor
        The tensor to accumulate.
    dim : int
        The axis along which the running sum is taken.

    Returns
    -------
    Tensor
        A framework-specific tensor holding the cumulative sums.
    """
    pass
389+
373390
def mean(input, dim):
374391
"""Reduce average the input tensor along the given dim.
375392
@@ -1489,6 +1506,31 @@ def edge_softmax(gidx, logits, eids, norm_by):
14891506
Tensor
14901507
Softmax value
14911508
"""
1509+
pass
1510+
1511+
def segment_reduce(op, x, offsets):
    """Segment reduction operator.

    It aggregates the value tensor ``x`` along the first dimension by
    segments. The segments are described by ``offsets``: segment ``i``
    covers rows ``offsets[i]`` (inclusive) up to ``offsets[i + 1]``
    (exclusive) of ``x``, so ``offsets`` has one more entry than there are
    segments. Zero-length segments are allowed.

    This is the backend-agnostic interface stub; each framework backend
    supplies the actual implementation.

    Parameters
    ----------
    op : str
        Aggregation method. Can be 'sum', 'max', 'min'.
    x : Tensor
        Value to aggregate.
    offsets : Tensor
        Segment boundaries (start offset of every segment followed by the
        end offset of the last one).

    Returns
    -------
    Tensor
        Aggregated tensor of shape ``(len(offsets) - 1, x.shape[1:])``.
    """
    pass
14921534

14931535

14941536
###############################################################################

python/dgl/backend/mxnet/sparse.py

+35-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import mxnet as mx
22
import numpy as np
33
from mxnet import nd
4-
from ...sparse import _gspmm, _gsddmm
4+
from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp
55
from ...base import dgl_warning, is_all, ALL
66
from .tensor import asnumpy, copy_to, zerocopy_from_numpy, context, to_backend_ctx
77

8-
__all__ = ['gspmm', 'gsddmm', 'edge_softmax']
8+
__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce']
99

1010

1111
def _scatter_nd(index, src, n_rows):
@@ -28,7 +28,7 @@ def _scatter_nd(index, src, n_rows):
2828
if ndim > 1:
2929
new_idx = index * stride + sum(offsets)
3030
else:
31-
new_idx = index
31+
new_idx = index
3232
src = src.reshape(-1)
3333
new_idx = new_idx.reshape(-1)
3434
rst = np.zeros((stride * n_rows,), dtype=src.dtype)
@@ -328,3 +328,35 @@ def backward(self, grad_out):
328328
def edge_softmax(gidx, logits, eids=ALL, norm_by='dst'):
    """Apply the ``EdgeSoftmax`` autograd function to ``logits``.

    A fresh ``EdgeSoftmax`` object is built from ``gidx``, ``eids`` and
    ``norm_by`` on every call and then invoked on ``logits``.
    """
    return EdgeSoftmax(gidx, eids, norm_by)(logits)
331+
332+
333+
class SegmentReduce(mx.autograd.Function):
    """MXNet autograd function wrapping the segment-reduce kernel.

    Parameters
    ----------
    op : str
        Reduction name forwarded to ``_segment_reduce``; ``'sum'`` gets a
        dedicated backward path, anything else goes through
        ``_bwd_segment_cmp``.
    offsets : NDArray
        Segment boundaries; ``offsets[-1]`` is used as the number of input
        rows in the backward pass.
    """

    def __init__(self, op, offsets):
        super(SegmentReduce, self).__init__()
        self.op = op
        self.offsets = offsets

    def forward(self, x):
        """Reduce ``x`` by segments and stash the arg tensor for backward."""
        y, arg = _segment_reduce(self.op, x, self.offsets)
        self.save_for_backward(arg)
        return y

    def backward(self, dy):
        """Map the output gradient ``dy`` back onto the input rows."""
        arg, = self.saved_tensors
        offsets = self.offsets
        m = offsets[-1].asscalar()
        if self.op == 'sum':
            # Row r belongs to segment i where i is the number of segment
            # end offsets that are <= r.  Count the ends with a histogram
            # followed by a prefix sum.  The histogram has m + 1 slots so
            # that end offsets equal to m (the last segment, and any
            # trailing zero-length segments) stay in bounds; that extra
            # slot is dropped after the cumulative sum.
            offsets_np = asnumpy(offsets[1:])
            indices_np = np.zeros((m + 1,), dtype=offsets_np.dtype)
            np.add.at(indices_np, offsets_np, np.ones_like(offsets_np))
            indices_np = np.cumsum(indices_np, -1)[:-1]
            indices = zerocopy_from_numpy(indices_np)
            # Each input row receives the gradient of its segment's output.
            dx = dy[indices]
        else:
            dx = _bwd_segment_cmp(dy, arg, m)
        return dx
358+
359+
360+
def segment_reduce(op, x, offsets):
    """Run ``SegmentReduce`` configured with ``op`` and ``offsets`` on ``x``."""
    return SegmentReduce(op, offsets)(x)

python/dgl/backend/mxnet/tensor.py

+3
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@ def sum(input, dim, keepdims=False):
152152
def reduce_sum(input):
    """Return the result of calling ``input.sum()`` with no arguments."""
    total = input.sum()
    return total
154154

155+
def cumsum(input, dim):
    """Cumulative sum of ``input`` along axis ``dim`` via ``nd.cumsum``."""
    running = nd.cumsum(input, axis=dim)
    return running
157+
155158
def mean(input, dim):
    """Average ``input`` along axis ``dim`` via ``nd.mean``."""
    averaged = nd.mean(input, axis=dim)
    return averaged
157160

python/dgl/backend/pytorch/sparse.py

+32-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import torch as th
22
from ...base import is_all, ALL
3-
from ...sparse import _gspmm, _gsddmm
3+
from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp
44

5-
__all__ = ['gspmm', 'gsddmm', 'edge_softmax']
5+
__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce']
66

77

88
def _reduce_grad(grad, shape):
@@ -231,6 +231,32 @@ def backward(ctx, grad_out):
231231
return None, grad_score, None, None
232232

233233

234+
class SegmentReduce(th.autograd.Function):
    """PyTorch autograd function wrapping the segment-reduce kernel.

    ``forward`` delegates to ``_segment_reduce``; ``backward`` routes the
    output gradient back to the input rows.
    """

    @staticmethod
    def forward(ctx, op, x, offsets):
        """Reduce ``x`` by the segments delimited by ``offsets``.

        The arg tensor returned by the kernel and ``offsets`` are saved for
        the backward pass; ``op`` is cached on ``ctx``.
        """
        y, arg = _segment_reduce(op, x, offsets)
        ctx.save_for_backward(arg, offsets)
        ctx.backward_cache = op
        return y

    @staticmethod
    def backward(ctx, dy):
        """Return gradients for ``(op, x, offsets)``; only ``x`` gets one."""
        op = ctx.backward_cache
        arg, offsets = ctx.saved_tensors
        # The last offset is the total number of rows covered by segments.
        m = offsets[-1].item()
        if op == 'sum':
            # Row r belongs to segment i where i is the number of segment
            # end offsets that are <= r.  Count the ends with a histogram
            # followed by a prefix sum.  The histogram has m + 1 slots so
            # that end offsets equal to m (the last segment, and any
            # trailing zero-length segments) stay in bounds; that extra
            # slot is dropped after the cumulative sum.
            ends = offsets[1:]
            indices = th.zeros(
                (m + 1,), device=ends.device, dtype=ends.dtype)
            indices.scatter_add_(0, ends, th.ones_like(ends))
            indices = th.cumsum(indices, -1)[:-1]
            # Each input row receives the gradient of its segment's output.
            dx = dy[indices]
        else:
            dx = _bwd_segment_cmp(dy, arg, m)
        return None, dx, None
258+
259+
234260
def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
    """Forward all arguments, in order, to ``GSpMM.apply``."""
    call_args = (gidx, op, reduce_op, lhs_data, rhs_data)
    return GSpMM.apply(*call_args)
236262

@@ -241,3 +267,7 @@ def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
241267

242268
def edge_softmax(gidx, logits, eids=ALL, norm_by='dst'):
    """Forward all arguments, in order, to ``EdgeSoftmax.apply``."""
    call_args = (gidx, logits, eids, norm_by)
    return EdgeSoftmax.apply(*call_args)
270+
271+
272+
def segment_reduce(op, x, offsets):
    """Forward all arguments, in order, to ``SegmentReduce.apply``."""
    call_args = (op, x, offsets)
    return SegmentReduce.apply(*call_args)

python/dgl/backend/pytorch/tensor.py

+3
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,9 @@ def sum(input, dim, keepdims=False):
120120
def reduce_sum(input):
    """Return the result of calling ``input.sum()`` with no arguments."""
    total = input.sum()
    return total
122122

123+
def cumsum(input, dim):
    """Cumulative sum of ``input`` along dimension ``dim`` via ``th.cumsum``."""
    running = th.cumsum(input, dim=dim)
    return running
125+
123126
def mean(input, dim):
    """Average ``input`` along dimension ``dim`` via ``th.mean``."""
    averaged = th.mean(input, dim=dim)
    return averaged
125128

python/dgl/backend/tensorflow/sparse.py

+28-3
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import tensorflow as tf
22
import numpy as np
3-
from .tensor import tensor, copy_to, context
3+
from .tensor import tensor, copy_to, context, asnumpy, zerocopy_from_numpy
44
from ...base import is_all, ALL
5-
from ...sparse import _gspmm, _gsddmm
5+
from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp
66

7-
__all__ = ['gspmm', 'gsddmm', 'edge_softmax']
7+
__all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce']
88

99

1010
def _scatter_nd(index, src, n_rows):
@@ -254,3 +254,28 @@ def _lambda(logits):
254254
return edge_softmax_real(gidx, logits, eids, norm_by)
255255
return _lambda(logits)
256256

257+
258+
def segment_reduce_real(op, x, offsets):
    """Run the segment-reduce kernel and build its gradient function.

    Parameters
    ----------
    op : str
        Reduction name forwarded to ``_segment_reduce``; ``'sum'`` gets a
        dedicated backward path, anything else goes through
        ``_bwd_segment_cmp``.
    x : Tensor
        Value to aggregate along its first dimension.
    offsets : Tensor
        Segment boundaries.

    Returns
    -------
    tuple
        ``(y, grad_fn)`` where ``y`` is the reduced tensor and ``grad_fn``
        maps the gradient of ``y`` to the gradient of ``x``.
    """
    y, arg = _segment_reduce(op, x, offsets)

    def segment_reduce_backward(dy):
        m = x.shape[0]
        if op == 'sum':
            # Row r belongs to segment i where i is the number of segment
            # end offsets that are <= r.  Count the ends with a histogram
            # followed by a prefix sum.  The histogram has m + 1 slots so
            # that end offsets equal to m (the last segment, and any
            # trailing zero-length segments) stay in bounds; that extra
            # slot is dropped after the cumulative sum.
            offsets_np = asnumpy(offsets[1:])
            indices_np = np.zeros((m + 1,), dtype=offsets_np.dtype)
            np.add.at(indices_np, offsets_np, np.ones_like(offsets_np))
            indices_np = np.cumsum(indices_np, -1)[:-1]
            indices = zerocopy_from_numpy(indices_np)
            # Each input row receives the gradient of its segment's output.
            dx = tf.gather(dy, indices)
        else:
            dx = _bwd_segment_cmp(dy, arg, m)
        return dx

    return y, segment_reduce_backward
275+
276+
277+
def segment_reduce(op, x, offsets):
    """Segment reduction with a custom gradient attached.

    ``op`` and ``offsets`` are captured by a closure so that only ``x`` is
    passed through ``tf.custom_gradient``.
    """
    def _forward_with_grad(feat):
        # Returns (output, gradient function) as tf.custom_gradient expects.
        return segment_reduce_real(op, feat, offsets)

    return tf.custom_gradient(_forward_with_grad)(x)

python/dgl/backend/tensorflow/tensor.py

+6
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,12 @@ def reduce_sum(input):
175175
return tf.reduce_sum(input)
176176

177177

178+
def cumsum(input, dim):
    """Cumulative sum of ``input`` along axis ``dim`` via ``tf.cumsum``.

    Boolean tensors are cast to ``tf.int32`` before accumulating.
    """
    # NOTE(review): presumably tf.cumsum does not accept bool inputs - confirm.
    values = tf.cast(input, tf.int32) if input.dtype == tf.bool else input
    return tf.cumsum(values, axis=dim)
182+
183+
178184
def mean(input, dim):
    """Average ``input`` along axis ``dim`` via ``tf.reduce_mean``."""
    averaged = tf.reduce_mean(input, axis=dim)
    return averaged
180186

python/dgl/ops/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
from .spmm import *
33
from .sddmm import *
44
from .edge_softmax import *
5+
from .segment import *

python/dgl/ops/segment.py

+15-16
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22

33
from ..base import DGLError
44
from .. import backend as F
5-
from .. import convert
6-
from .. import function as fn
75

86

97
def segment_reduce(seglen, value, reducer='sum'):
@@ -41,20 +39,21 @@ def segment_reduce(seglen, value, reducer='sum'):
4139
[5., 5., 5.],
4240
[4., 4., 4.]])
4341
"""
44-
ctx = F.context(seglen)
45-
# TODO(minjie): a more efficient implementation is to create a graph
46-
# directly from a CSR structure.
47-
u = F.copy_to(F.arange(0, F.shape(value)[0], F.int32), ctx)
48-
v = F.repeat(F.copy_to(F.arange(0, len(seglen), F.int32), ctx),
49-
seglen, dim=0)
50-
if len(u) != len(v):
51-
raise DGLError("Invalid seglen array:", seglen,
52-
". Its summation must be equal to value.shape[0].")
53-
num_nodes = {'_U': len(u), '_V': len(seglen)}
54-
g = convert.heterograph({('_U', '_E', '_V'): (u, v)}, num_nodes_dict=num_nodes)
55-
g.srcdata['h'] = value
56-
g.update_all(fn.copy_u('h', 'm'), getattr(fn, reducer)('m', 'h'))
57-
return g.dstdata['h']
42+
offsets = F.cumsum(
43+
F.cat([F.zeros((1,), F.dtype(seglen), F.context(seglen)), seglen], 0), 0)
44+
if reducer == 'mean':
45+
rst = F.segment_reduce('sum', value, offsets)
46+
rst_shape = F.shape(rst)
47+
z = F.astype(F.clamp(seglen, 1, len(value)), F.dtype(rst))
48+
z_shape = (rst_shape[0],) + (1,) * (len(rst_shape) - 1)
49+
return rst / F.reshape(z, z_shape)
50+
elif reducer in ['min', 'sum', 'max']:
51+
rst = F.segment_reduce(reducer, value, offsets)
52+
if reducer in ['min', 'max']:
53+
rst = F.replace_inf_with_zero(rst)
54+
return rst
55+
else:
56+
raise DGLError("reducer {} not recognized.".format(reducer))
5857

5958

6059
def segment_softmax(seglen, value):

python/dgl/sparse.py

+77
Original file line numberDiff line numberDiff line change
@@ -248,4 +248,81 @@ def _gsddmm(gidx, op, lhs, rhs, lhs_target='u', rhs_target='v'):
248248
return out
249249

250250

251+
def _segment_reduce(op, feat, offsets):
    r"""Segment reduction operator.

    It aggregates the tensor ``feat`` along the first dimension by segments.
    The argument ``offsets`` holds the segment boundaries: it has one more
    entry than there are segments, and the output has
    ``len(offsets) - 1`` rows. Zero-length segments are allowed.

    Parameters
    ----------
    op : str
        Aggregation method. Can be 'sum', 'max', 'min'.
    feat : Tensor
        Value to aggregate.
    offsets : Tensor
        Segment boundaries (start offset of every segment followed by the
        end offset of the last one).

    Returns
    -------
    tuple(Tensor)
        The first tensor is the aggregated tensor of shape
        ``(len(offsets) - 1, feat.shape[1:])``. The second tensor records
        the argmin/max at each position for computing gradients; it is
        ``None`` unless ``op`` is 'min' or 'max'.

    Notes
    -----
    This function does not handle gradients.
    """
    n = F.shape(offsets)[0] - 1
    out_shp = (n,) + F.shape(feat)[1:]
    ctx = F.context(feat)
    dtype = F.dtype(feat)
    idtype = F.dtype(offsets)
    out = F.zeros(out_shp, dtype, ctx)
    arg = None
    # The arg buffer is only allocated for the comparison reducers.
    if op in ['min', 'max']:
        arg = F.zeros(out_shp, idtype, ctx)
    arg_nd = to_dgl_nd_for_write(arg)
    _CAPI_DGLKernelSegmentReduce(op,
                                 to_dgl_nd(feat),
                                 to_dgl_nd(offsets),
                                 to_dgl_nd_for_write(out),
                                 arg_nd)
    arg = None if arg is None else F.zerocopy_from_dgl_ndarray(arg_nd)
    return out, arg
296+
297+
298+
def _bwd_segment_cmp(feat, arg, m):
    r"""Backward pass of segment reduction for the 'min'/'max' reducers.

    Given the gradient of the segment-reduction output, produce the
    gradient with respect to the reduction input.

    Parameters
    ----------
    feat : Tensor
        Gradient of the reduction output.
    arg : Tensor
        The ArgMin/Max tensor produced by the segment_reduce op.
    m : int
        Size of the first dimension of the input gradient.

    Returns
    -------
    Tensor
        The input gradient, of shape ``(m,) + feat.shape[1:]``.
    """
    # Zero-initialised result buffer handed to the kernel for writing.
    grad_in = F.zeros((m,) + F.shape(feat)[1:], F.dtype(feat), F.context(feat))
    _CAPI_DGLKernelBwdSegmentCmp(
        to_dgl_nd(feat),
        to_dgl_nd(arg),
        to_dgl_nd_for_write(grad_in),
    )
    return grad_in
326+
327+
251328
_init_api("dgl.sparse")

0 commit comments

Comments
 (0)