1
1
import mxnet as mx
2
2
import numpy as np
3
3
from mxnet import nd
4
- from ...sparse import _gspmm , _gsddmm
4
+ from ...sparse import _gspmm , _gsddmm , _segment_reduce , _bwd_segment_cmp
5
5
from ...base import dgl_warning , is_all , ALL
6
6
from .tensor import asnumpy , copy_to , zerocopy_from_numpy , context , to_backend_ctx
7
7
8
- __all__ = ['gspmm' , 'gsddmm' , 'edge_softmax' ]
8
+ __all__ = ['gspmm' , 'gsddmm' , 'edge_softmax' , 'segment_reduce' ]
9
9
10
10
11
11
def _scatter_nd (index , src , n_rows ):
@@ -28,7 +28,7 @@ def _scatter_nd(index, src, n_rows):
28
28
if ndim > 1 :
29
29
new_idx = index * stride + sum (offsets )
30
30
else :
31
- new_idx = index
31
+ new_idx = index
32
32
src = src .reshape (- 1 )
33
33
new_idx = new_idx .reshape (- 1 )
34
34
rst = np .zeros ((stride * n_rows ,), dtype = src .dtype )
@@ -328,3 +328,35 @@ def backward(self, grad_out):
328
328
def edge_softmax (gidx , logits , eids = ALL , norm_by = 'dst' ):
329
329
softmax_op = EdgeSoftmax (gidx , eids , norm_by )
330
330
return softmax_op (logits )
331
+
332
+
333
+ class SegmentReduce (mx .autograd .Function ):
334
+ def __init__ (self , op , offsets ):
335
+ super (SegmentReduce , self ).__init__ ()
336
+ self .op = op
337
+ self .offsets = offsets
338
+
339
+ def forward (self , x ):
340
+ y , arg = _segment_reduce (self .op , x , self .offsets )
341
+ self .save_for_backward (arg )
342
+ return y
343
+
344
+ def backward (self , dy ):
345
+ arg , = self .saved_tensors
346
+ offsets = self .offsets
347
+ m = offsets [- 1 ].asscalar ()
348
+ if self .op == 'sum' :
349
+ offsets_np = asnumpy (offsets [1 :- 1 ])
350
+ indices_np = np .zeros ((m ,), dtype = offsets_np .dtype )
351
+ np .add .at (indices_np , offsets_np , np .ones_like (offsets_np ))
352
+ indices_np = np .cumsum (indices_np , - 1 )
353
+ indices = zerocopy_from_numpy (indices_np )
354
+ dx = dy [indices ]
355
+ else :
356
+ dx = _bwd_segment_cmp (dy , arg , m )
357
+ return dx
358
+
359
+
360
+ def segment_reduce (op , x , offsets ):
361
+ segment_reduce_op = SegmentReduce (op , offsets )
362
+ return segment_reduce_op (x )
0 commit comments