Skip to content

Commit 1e5ff17

Browse files
mufeiliUbuntu
authored andcommitted
[Dataset] CornellDataset and TexasDataset (dmlc#5513)
Co-authored-by: Ubuntu <[email protected]>
1 parent d32612e commit 1e5ff17

File tree

4 files changed

+203
-13
lines changed

4 files changed

+203
-13
lines changed

docs/source/api/python/dgl.data.rst

+2
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ Datasets for node classification/regression tasks
5959
ChameleonDataset
6060
SquirrelDataset
6161
ActorDataset
62+
CornellDataset
63+
TexasDataset
6264

6365
Edge Prediction Datasets
6466
---------------------------------------

python/dgl/data/__init__.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,13 @@
5454
from .tu import LegacyTUDataset, TUDataset
5555
from .utils import *
5656
from .cluster import CLUSTERDataset
57+
from .geom_gcn import (
58+
ChameleonDataset,
59+
CornellDataset,
60+
SquirrelDataset,
61+
TexasDataset,
62+
)
5763
from .pattern import PATTERNDataset
58-
from .wiki_network import ChameleonDataset, SquirrelDataset
5964
from .wikics import WikiCSDataset
6065
from .yelp import YelpDataset
6166
from .zinc import ZINCDataset

python/dgl/data/wiki_network.py python/dgl/data/geom_gcn.py

+161-12
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
"""
2-
Wikipedia page-page networks on two topics: chameleons and squirrels.
3-
"""
1+
"""Datasets introduced in the Geom-GCN paper."""
42
import os
53

64
import numpy as np
@@ -10,11 +8,10 @@
108
from .utils import _get_dgl_url
119

1210

13-
class WikiNetworkDataset(DGLBuiltinDataset):
14-
r"""Wikipedia page-page networks from `Multi-scale Attributed
15-
Node Embedding <https://arxiv.org/abs/1909.13021>`__ and later modified by
11+
class GeomGCNDataset(DGLBuiltinDataset):
12+
r"""Datasets introduced in
1613
`Geom-GCN: Geometric Graph Convolutional Networks
17-
<https://arxiv.org/abs/2002.05287>`
14+
<https://arxiv.org/abs/2002.05287>`__
1815
1916
Parameters
2017
----------
@@ -34,7 +31,7 @@ class WikiNetworkDataset(DGLBuiltinDataset):
3431

3532
def __init__(self, name, raw_dir, force_reload, verbose, transform):
3633
url = _get_dgl_url(f"dataset/{name}.zip")
37-
super(WikiNetworkDataset, self).__init__(
34+
super(GeomGCNDataset, self).__init__(
3835
name=name,
3936
url=url,
4037
raw_dir=raw_dir,
@@ -106,11 +103,11 @@ def num_classes(self):
106103
return self._num_classes
107104

108105

109-
class ChameleonDataset(WikiNetworkDataset):
106+
class ChameleonDataset(GeomGCNDataset):
110107
r"""Wikipedia page-page network on chameleons from `Multi-scale Attributed
111108
Node Embedding <https://arxiv.org/abs/1909.13021>`__ and later modified by
112109
`Geom-GCN: Geometric Graph Convolutional Networks
113-
<https://arxiv.org/abs/2002.05287>`
110+
<https://arxiv.org/abs/2002.05287>`__
114111
115112
Nodes represent articles from the English Wikipedia, edges reflect mutual
116113
links between them. Node features indicate the presence of particular nouns
@@ -182,11 +179,11 @@ def __init__(
182179
)
183180

184181

185-
class SquirrelDataset(WikiNetworkDataset):
182+
class SquirrelDataset(GeomGCNDataset):
186183
r"""Wikipedia page-page network on squirrels from `Multi-scale Attributed
187184
Node Embedding <https://arxiv.org/abs/1909.13021>`__ and later modified by
188185
`Geom-GCN: Geometric Graph Convolutional Networks
189-
<https://arxiv.org/abs/2002.05287>`
186+
<https://arxiv.org/abs/2002.05287>`__
190187
191188
Nodes represent articles from the English Wikipedia, edges reflect mutual
192189
links between them. Node features indicate the presence of particular nouns
@@ -256,3 +253,155 @@ def __init__(
256253
verbose=verbose,
257254
transform=transform,
258255
)
256+
257+
258+
class CornellDataset(GeomGCNDataset):
259+
r"""Cornell subset of
260+
`WebKB <http://www.cs.cmu.edu/afs/cs.cmu.edu/project/theo-11/www/wwkb/>`__,
261+
later modified by `Geom-GCN: Geometric Graph Convolutional Networks
262+
<https://arxiv.org/abs/2002.05287>`__
263+
264+
Nodes represent web pages. Edges represent hyperlinks between them. Node
265+
features are the bag-of-words representation of web pages. The web pages
266+
are manually classified into the five categories, student, project, course,
267+
staff, and faculty.
268+
269+
Statistics:
270+
271+
- Nodes: 183
272+
- Edges: 298
273+
- Number of Classes: 5
274+
- 10 train/val/test splits
275+
276+
- Train: 87
277+
- Val: 59
278+
- Test: 37
279+
280+
Parameters
281+
----------
282+
raw_dir : str, optional
283+
Raw file directory to store the processed data. Default: ~/.dgl/
284+
force_reload : bool, optional
285+
Whether to re-download the data source. Default: False
286+
verbose : bool, optional
287+
Whether to print progress information. Default: True
288+
transform : callable, optional
289+
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
290+
a transformed version. The :class:`~dgl.DGLGraph` object will be
291+
transformed before every access. Default: None
292+
293+
Attributes
294+
----------
295+
num_classes : int
296+
Number of node classes
297+
298+
Notes
299+
-----
300+
The graph does not come with edges for both directions.
301+
302+
Examples
303+
--------
304+
305+
>>> from dgl.data import CornellDataset
306+
>>> dataset = CornellDataset()
307+
>>> g = dataset[0]
308+
>>> num_classes = dataset.num_classes
309+
310+
>>> # get node features
311+
>>> feat = g.ndata["feat"]
312+
313+
>>> # get data split
314+
>>> train_mask = g.ndata["train_mask"]
315+
>>> val_mask = g.ndata["val_mask"]
316+
>>> test_mask = g.ndata["test_mask"]
317+
318+
>>> # get labels
319+
>>> label = g.ndata['label']
320+
"""
321+
322+
def __init__(
323+
self, raw_dir=None, force_reload=False, verbose=True, transform=None
324+
):
325+
super(CornellDataset, self).__init__(
326+
name="cornell",
327+
raw_dir=raw_dir,
328+
force_reload=force_reload,
329+
verbose=verbose,
330+
transform=transform,
331+
)
332+
333+
334+
class TexasDataset(GeomGCNDataset):
335+
r"""Texas subset of
336+
`WebKB <http://www.cs.cmu.edu/afs/cs.cmu.edu/project/theo-11/www/wwkb/>`__,
337+
later modified by `Geom-GCN: Geometric Graph Convolutional Networks
338+
<https://arxiv.org/abs/2002.05287>`__
339+
340+
Nodes represent web pages. Edges represent hyperlinks between them. Node
341+
features are the bag-of-words representation of web pages. The web pages
342+
are manually classified into the five categories, student, project, course,
343+
staff, and faculty.
344+
345+
Statistics:
346+
347+
- Nodes: 183
348+
- Edges: 325
349+
- Number of Classes: 5
350+
- 10 train/val/test splits
351+
352+
- Train: 87
353+
- Val: 59
354+
- Test: 37
355+
356+
Parameters
357+
----------
358+
raw_dir : str, optional
359+
Raw file directory to store the processed data. Default: ~/.dgl/
360+
force_reload : bool, optional
361+
Whether to re-download the data source. Default: False
362+
verbose : bool, optional
363+
Whether to print progress information. Default: True
364+
transform : callable, optional
365+
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
366+
a transformed version. The :class:`~dgl.DGLGraph` object will be
367+
transformed before every access. Default: None
368+
369+
Attributes
370+
----------
371+
num_classes : int
372+
Number of node classes
373+
374+
Notes
375+
-----
376+
The graph does not come with edges for both directions.
377+
378+
Examples
379+
--------
380+
381+
>>> from dgl.data import TexasDataset
382+
>>> dataset = TexasDataset()
383+
>>> g = dataset[0]
384+
>>> num_classes = dataset.num_classes
385+
386+
>>> # get node features
387+
>>> feat = g.ndata["feat"]
388+
389+
>>> # get data split
390+
>>> train_mask = g.ndata["train_mask"]
391+
>>> val_mask = g.ndata["val_mask"]
392+
>>> test_mask = g.ndata["test_mask"]
393+
394+
>>> # get labels
395+
>>> label = g.ndata['label']
396+
"""
397+
398+
def __init__(
399+
self, raw_dir=None, force_reload=False, verbose=True, transform=None
400+
):
401+
super(TexasDataset, self).__init__(
402+
name="texas",
403+
raw_dir=raw_dir,
404+
force_reload=force_reload,
405+
verbose=verbose,
406+
transform=transform,
407+
)

tests/python/common/data/test_wiki_network.py tests/python/common/data/test_geom_gcn.py

+34
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,37 @@ def test_squirrel():
3737
assert g.num_edges() == 217073
3838
g2 = dgl.data.SquirrelDataset(force_reload=True, transform=transform)[0]
3939
assert g2.num_edges() - g.num_edges() == g.num_nodes()
40+
41+
42+
@unittest.skipIf(
43+
F._default_context_str == "gpu",
44+
reason="Datasets don't need to be tested on GPU.",
45+
)
46+
@unittest.skipIf(
47+
dgl.backend.backend_name != "pytorch", reason="only supports pytorch"
48+
)
49+
def test_cornell():
50+
transform = dgl.AddSelfLoop(allow_duplicate=True)
51+
52+
g = dgl.data.CornellDataset(force_reload=True)[0]
53+
assert g.num_nodes() == 183
54+
assert g.num_edges() == 298
55+
g2 = dgl.data.CornellDataset(force_reload=True, transform=transform)[0]
56+
assert g2.num_edges() - g.num_edges() == g.num_nodes()
57+
58+
59+
@unittest.skipIf(
60+
F._default_context_str == "gpu",
61+
reason="Datasets don't need to be tested on GPU.",
62+
)
63+
@unittest.skipIf(
64+
dgl.backend.backend_name != "pytorch", reason="only supports pytorch"
65+
)
66+
def test_texas():
67+
transform = dgl.AddSelfLoop(allow_duplicate=True)
68+
69+
g = dgl.data.TexasDataset(force_reload=True)[0]
70+
assert g.num_nodes() == 183
71+
assert g.num_edges() == 325
72+
g2 = dgl.data.TexasDataset(force_reload=True, transform=transform)[0]
73+
assert g2.num_edges() - g.num_edges() == g.num_nodes()

0 commit comments

Comments
 (0)