import copy
import os.path as osp
import re
import sys
import warnings
from collections.abc import Sequence
from typing import (
    Any,
    Callable,
    Iterable,
    Iterator,
    List,
    Optional,
    Tuple,
    Union,
)

import numpy as np
import torch.utils.data
from torch import Tensor

from torch_geometric.data.data import BaseData
from torch_geometric.io import fs

IndexType = Union[slice, Tensor, np.ndarray, Sequence]

MISSING = '???'


class Dataset(torch.utils.data.Dataset):
    r"""Dataset base class for creating graph datasets.
    See `here <https://pytorch-geometric.readthedocs.io/en/latest/tutorial/
    create_dataset.html>`__ for the accompanying tutorial.

    Args:
        root (str, optional): Root directory where the dataset should be
            saved. (default: :obj:`None`)
        transform (callable, optional): A function/transform that takes in a
            :class:`~torch_geometric.data.Data` or
            :class:`~torch_geometric.data.HeteroData` object and returns a
            transformed version.
            The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            a :class:`~torch_geometric.data.Data` or
            :class:`~torch_geometric.data.HeteroData` object and returns a
            transformed version.
            The data object will be transformed before being saved to disk.
            (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in a
            :class:`~torch_geometric.data.Data` or
            :class:`~torch_geometric.data.HeteroData` object and returns a
            boolean value, indicating whether the data object should be
            included in the final dataset. (default: :obj:`None`)
        log (bool, optional): Whether to print any console output while
            downloading and processing the dataset. (default: :obj:`True`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
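
    A minimal sketch of a custom subclass (the file names and the
    :obj:`read_graph` helper below are hypothetical placeholders):

    .. code-block:: python

        import os.path as osp

        import torch

        from torch_geometric.data import Dataset


        class MyDataset(Dataset):
            @property
            def raw_file_names(self):
                return ['graph_0.bin', 'graph_1.bin']

            @property
            def processed_file_names(self):
                return ['data_0.pt', 'data_1.pt']

            def download(self):
                # Download raw files into `self.raw_dir` here.
                pass

            def process(self):
                for i, path in enumerate(self.raw_paths):
                    data = read_graph(path)  # hypothetical reader
                    if self.pre_filter is not None and not self.pre_filter(data):
                        continue
                    if self.pre_transform is not None:
                        data = self.pre_transform(data)
                    torch.save(data,
                               osp.join(self.processed_dir, f'data_{i}.pt'))

            def len(self):
                return len(self.processed_file_names)

            def get(self, idx):
                path = osp.join(self.processed_dir, f'data_{idx}.pt')
                return torch.load(path, weights_only=False)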
"""
    @property
    def raw_file_names(self) -> Union[str, List[str], Tuple[str, ...]]:
        r"""The names of the files in the :obj:`self.raw_dir` folder that
        must be present in order to skip downloading.
        """
        raise NotImplementedError

    @property
    def processed_file_names(self) -> Union[str, List[str], Tuple[str, ...]]:
        r"""The names of the files in the :obj:`self.processed_dir` folder
        that must be present in order to skip processing.
        """
        raise NotImplementedError

    def download(self) -> None:
        r"""Downloads the dataset to the :obj:`self.raw_dir` folder."""
        raise NotImplementedError

    def process(self) -> None:
        r"""Processes the dataset to the :obj:`self.processed_dir` folder."""
        raise NotImplementedError

    def len(self) -> int:
        r"""Returns the number of data objects stored in the dataset."""
        raise NotImplementedError

    def get(self, idx: int) -> BaseData:
        r"""Gets the data object at index :obj:`idx`."""
        raise NotImplementedError

    def __init__(
        self,
        root: Optional[str] = None,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        log: bool = True,
        force_reload: bool = False,
    ) -> None:
        super().__init__()

        if isinstance(root, str):
            root = osp.expanduser(fs.normpath(root))

        self.root = root or MISSING
        self.transform = transform
        self.pre_transform = pre_transform
        self.pre_filter = pre_filter
        self.log = log
        self._indices: Optional[Sequence] = None
        self.force_reload = force_reload

        if self.has_download:
            self._download()

        if self.has_process:
            self._process()

    def indices(self) -> Sequence:
        return range(self.len()) if self._indices is None else self._indices

    @property
    def raw_dir(self) -> str:
        return osp.join(self.root, 'raw')

    @property
    def processed_dir(self) -> str:
        return osp.join(self.root, 'processed')

    @property
    def num_node_features(self) -> int:
        r"""Returns the number of features per node in the dataset."""
        data = self[0]
        # Do not fill cache for `InMemoryDataset`:
        if hasattr(self, '_data_list') and self._data_list is not None:
            self._data_list[0] = None
        data = data[0] if isinstance(data, tuple) else data
        if hasattr(data, 'num_node_features'):
            return data.num_node_features
        raise AttributeError(f"'{data.__class__.__name__}' object has no "
                             f"attribute 'num_node_features'")

    @property
    def num_features(self) -> int:
        r"""Returns the number of features per node in the dataset.
        Alias for :py:attr:`~num_node_features`.
        """
        return self.num_node_features

    @property
    def num_edge_features(self) -> int:
        r"""Returns the number of features per edge in the dataset."""
        data = self[0]
        # Do not fill cache for `InMemoryDataset`:
        if hasattr(self, '_data_list') and self._data_list is not None:
            self._data_list[0] = None
        data = data[0] if isinstance(data, tuple) else data
        if hasattr(data, 'num_edge_features'):
            return data.num_edge_features
        raise AttributeError(f"'{data.__class__.__name__}' object has no "
                             f"attribute 'num_edge_features'")

    def _infer_num_classes(self, y: Optional[Tensor]) -> int:
        if y is None:
            return 0
        elif y.numel() == y.size(0) and not torch.is_floating_point(y):
            # One integer label per example: classes are `0..y.max()`:
            return int(y.max()) + 1
        elif y.numel() == y.size(0) and torch.is_floating_point(y):
            # One floating-point label per example, e.g., binary 0.0/1.0
            # labels or regression targets:
            num_classes = torch.unique(y).numel()
            if num_classes > 2:
                warnings.warn("Found floating-point labels while calling "
                              "`dataset.num_classes`. Returning the number of "
                              "unique elements. Please make sure that this "
                              "is expected before proceeding.")
            return num_classes
        else:
            # Multi-dimensional labels, e.g., one-hot or multi-task targets:
            return y.size(-1)

    @property
    def num_classes(self) -> int:
        r"""Returns the number of classes in the dataset.
        """
        # We iterate over the dataset and collect all labels to determine the
        # maximum number of classes. Importantly, in rare cases, `__getitem__`
        # may produce a tuple of data objects (e.g., when used in combination
        # with `RandomLinkSplit`), so we take care of this case here as well:
        data_list = _get_flattened_data_list([data for data in self])

        if 'y' in data_list[0] and isinstance(data_list[0].y, Tensor):
            y = torch.cat([data.y for data in data_list if 'y' in data], dim=0)
        else:
            y = torch.as_tensor([data.y for data in data_list if 'y' in data])

        # Do not fill cache for `InMemoryDataset`:
        if hasattr(self, '_data_list') and self._data_list is not None:
            self._data_list = self.len() * [None]

        return self._infer_num_classes(y)

    @property
    def raw_paths(self) -> List[str]:
        r"""The absolute filepaths that must be present in order to skip
        downloading.
        """
        files = self.raw_file_names
        # Prevent a common source of error in which `file_names` are not
        # defined as a property.
        if isinstance(files, Callable):
            files = files()
        return [osp.join(self.raw_dir, f) for f in to_list(files)]

    @property
    def processed_paths(self) -> List[str]:
        r"""The absolute filepaths that must be present in order to skip
        processing.
        """
        files = self.processed_file_names
        # Prevent a common source of error in which `file_names` are not
        # defined as a property.
        if isinstance(files, Callable):
            files = files()
        return [osp.join(self.processed_dir, f) for f in to_list(files)]

    @property
    def has_download(self) -> bool:
        r"""Checks whether the dataset defines a :meth:`download` method."""
        return overrides_method(self.__class__, 'download')

    def _download(self):
        if files_exist(self.raw_paths):  # pragma: no cover
            return

        fs.makedirs(self.raw_dir, exist_ok=True)
        self.download()

    @property
    def has_process(self) -> bool:
        r"""Checks whether the dataset defines a :meth:`process` method."""
        return overrides_method(self.__class__, 'process')

    def _process(self):
        f = osp.join(self.processed_dir, 'pre_transform.pt')
        if osp.exists(f) and torch.load(f, weights_only=False) != _repr(
                self.pre_transform):
            warnings.warn(
                "The `pre_transform` argument differs from the one used in "
                "the pre-processed version of this dataset. If you want to "
                "make use of another pre-processing technique, pass "
                "`force_reload=True` explicitly to reload the dataset.")

        f = osp.join(self.processed_dir, 'pre_filter.pt')
        if osp.exists(f) and torch.load(f, weights_only=False) != _repr(
                self.pre_filter):
            warnings.warn(
                "The `pre_filter` argument differs from the one used in "
                "the pre-processed version of this dataset. If you want to "
                "make use of another pre-filtering technique, pass "
                "`force_reload=True` explicitly to reload the dataset.")

        if not self.force_reload and files_exist(self.processed_paths):
            return

        if self.log and 'pytest' not in sys.modules:
            print('Processing...', file=sys.stderr)

        fs.makedirs(self.processed_dir, exist_ok=True)
        self.process()

        path = osp.join(self.processed_dir, 'pre_transform.pt')
        fs.torch_save(_repr(self.pre_transform), path)
        path = osp.join(self.processed_dir, 'pre_filter.pt')
        fs.torch_save(_repr(self.pre_filter), path)

        if self.log and 'pytest' not in sys.modules:
            print('Done!', file=sys.stderr)

    def __len__(self) -> int:
        r"""The number of examples in the dataset."""
        return len(self.indices())

    def __getitem__(
        self,
        idx: Union[int, np.integer, IndexType],
    ) -> Union['Dataset', BaseData]:
        r"""In case :obj:`idx` is of type integer, will return the data object
        at index :obj:`idx` (and transform it in case :obj:`transform` is
        present).

        In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list,
        a tuple, or a :obj:`torch.Tensor` or :obj:`np.ndarray` of type long or
        bool, will return a subset of the dataset at the specified indices.
        """
        if (isinstance(idx, (int, np.integer))
                or (isinstance(idx, Tensor) and idx.dim() == 0)
                or (isinstance(idx, np.ndarray) and np.isscalar(idx))):
            data = self.get(self.indices()[idx])
            data = data if self.transform is None else self.transform(data)
            return data
        else:
            return self.index_select(idx)

    def __iter__(self) -> Iterator[BaseData]:
        for i in range(len(self)):
            yield self[i]

    def index_select(self, idx: IndexType) -> 'Dataset':
        r"""Creates a subset of the dataset from specified indices :obj:`idx`.
        Indices :obj:`idx` can be a slicing object, *e.g.*, :obj:`[2:5]`, a
        list, a tuple, or a :obj:`torch.Tensor` or :obj:`np.ndarray` of type
        long or bool.
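
        For example, assuming a hypothetical :obj:`dataset` with exactly five
        examples:

        .. code-block:: python

            subset = dataset.index_select([0, 2, 4])
            subset = dataset.index_select(slice(2, 5))  # same as dataset[2:5]
            subset = dataset.index_select(
                torch.tensor([True, False, True, False, True]))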
"""
indices = self.indices()
if isinstance(idx, slice):
start, stop, step = idx.start, idx.stop, idx.step
# Allow floating-point slicing, e.g., dataset[:0.9]
if isinstance(start, float):
start = round(start * len(self))
if isinstance(stop, float):
stop = round(stop * len(self))
idx = slice(start, stop, step)
indices = indices[idx]
elif isinstance(idx, Tensor) and idx.dtype == torch.long:
return self.index_select(idx.flatten().tolist())
elif isinstance(idx, Tensor) and idx.dtype == torch.bool:
idx = idx.flatten().nonzero(as_tuple=False)
return self.index_select(idx.flatten().tolist())
elif isinstance(idx, np.ndarray) and idx.dtype == np.int64:
return self.index_select(idx.flatten().tolist())
elif isinstance(idx, np.ndarray) and idx.dtype == bool:
idx = idx.flatten().nonzero()[0]
return self.index_select(idx.flatten().tolist())
elif isinstance(idx, Sequence) and not isinstance(idx, str):
indices = [indices[i] for i in idx]
else:
raise IndexError(
f"Only slices (':'), list, tuples, torch.tensor and "
f"np.ndarray of dtype long or bool are valid indices (got "
f"'{type(idx).__name__}')")
dataset = copy.copy(self)
dataset._indices = indices
return dataset

    def shuffle(
        self,
        return_perm: bool = False,
    ) -> Union['Dataset', Tuple['Dataset', Tensor]]:
        r"""Randomly shuffles the examples in the dataset.

        Args:
            return_perm (bool, optional): If set to :obj:`True`, will also
                return the random permutation used to shuffle the dataset.
                (default: :obj:`False`)
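
        For example:

        .. code-block:: python

            dataset, perm = dataset.shuffle(return_perm=True)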
"""
perm = torch.randperm(len(self))
dataset = self.index_select(perm)
return (dataset, perm) if return_perm is True else dataset
def __repr__(self) -> str:
arg_repr = str(len(self)) if len(self) > 1 else ''
return f'{self.__class__.__name__}({arg_repr})'

    def get_summary(self) -> Any:
        r"""Collects summary statistics for the dataset."""
        from torch_geometric.data.summary import Summary
        return Summary.from_dataset(self)

    def print_summary(self, fmt: str = "psql") -> None:
        r"""Prints summary statistics of the dataset to the console.

        Args:
            fmt (str, optional): Summary tables format. Available table
                formats can be found `here
                <https://github.com/astanin/python-tabulate?
                tab=readme-ov-file#table-format>`__. (default: :obj:`"psql"`)
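
        For example (:obj:`"github"` is one of the formats supported by
        :obj:`tabulate`):

        .. code-block:: python

            dataset.print_summary()          # "psql"-style table
            dataset.print_summary("github")  # GitHub-flavored Markdown table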
"""
print(self.get_summary().format(fmt=fmt))

    def to_datapipe(self) -> Any:
        r"""Converts the dataset into a :class:`torch.utils.data.DataPipe`.
        The returned instance can then be used with :pyg:`PyG's` built-in
        :class:`DataPipes` for batching graphs as follows:

        .. code-block:: python

            from torch_geometric.datasets import QM9

            dp = QM9(root='./data/QM9/').to_datapipe()
            dp = dp.batch_graphs(batch_size=2, drop_last=True)

            for batch in dp:
                pass

        See the `PyTorch tutorial
        <https://pytorch.org/data/main/tutorial.html>`_ for further background
        on DataPipes.
        """
        from torch_geometric.data.datapipes import DatasetAdapter
        return DatasetAdapter(self)


def overrides_method(cls, method_name: str) -> bool:
    from torch_geometric.data import InMemoryDataset

    if method_name in cls.__dict__:
        return True

    # Recurse into base classes, but stop at the (abstract) base classes
    # `Dataset` and `InMemoryDataset` themselves:
    out = False
    for base in cls.__bases__:
        if base != Dataset and base != InMemoryDataset:
            out |= overrides_method(base, method_name)
    return out


def to_list(value: Any) -> Sequence:
    if isinstance(value, Sequence) and not isinstance(value, str):
        return value
    else:
        return [value]


def files_exist(files: List[str]) -> bool:
    # NOTE: We return `False` in case `files` is empty, leading to a
    # re-processing of files on every instantiation.
    return len(files) != 0 and all([fs.exists(f) for f in files])


def _repr(obj: Any) -> str:
    if obj is None:
        return 'None'
    # Strip everything after the first whitespace inside angle brackets,
    # e.g., '<function foo at 0x7f...>' -> '<function>':
    return re.sub('(<.*?)\\s.*(>)', r'\1\2', str(obj))


def _get_flattened_data_list(data_list: Iterable[Any]) -> List[BaseData]:
    # Recursively flatten (possibly nested) tuples, lists, and dicts into a
    # flat list of `BaseData` objects:
    outs: List[BaseData] = []
    for data in data_list:
        if isinstance(data, BaseData):
            outs.append(data)
        elif isinstance(data, (tuple, list)):
            outs.extend(_get_flattened_data_list(data))
        elif isinstance(data, dict):
            outs.extend(_get_flattened_data_list(data.values()))
    return outs