Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Pytorch Geometric integration #73

Merged
merged 69 commits into from
Jul 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
69 commits
Select commit Hold shift + click to select a range
ea4e7b6
Initial structure of GraphClassification model.py
Feb 2, 2021
1255f8f
Improvement of model.py. Still need to debug etc
Feb 2, 2021
365863f
BasicDataset Implemented
Feb 6, 2021
02b0f6a
Create __init__.py
Feb 6, 2021
f28e949
Implemented dataset and DataModule as for image processing
Feb 7, 2021
ad76827
Pipeline taken from images.
Feb 8, 2021
ea6ee9d
Initial structure of GraphClassification model.py
Feb 2, 2021
8b93a4a
Improvement of model.py. Still need to debug etc
Feb 2, 2021
6b4d7e3
BasicDataset Implemented
Feb 6, 2021
48dcf2d
Implemented dataset and DataModule as for image processing
Feb 7, 2021
49dfe4d
Pipeline taken from images.
Feb 8, 2021
151f7d9
Choice of model implemented (you can pass a model to GraphClassifier)
Mar 21, 2021
6236d95
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 14, 2021
28e315f
Initial readaptation of the structure
May 14, 2021
08ace7e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 14, 2021
35a79cb
Minimal structure of how to structure data.py files
May 14, 2021
93dd638
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 14, 2021
82576a5
Minor corrections
May 17, 2021
089cb07
update
tchaton May 17, 2021
920fc68
i
tchaton May 17, 2021
c970b5f
update
tchaton May 17, 2021
fe41405
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 17, 2021
868b8d7
Added auto_dataset.num_features
May 24, 2021
debf5da
Deleted manually included num_features so that it is extracted from G…
May 24, 2021
1e6b2b0
Test for GraphClassification implemented
May 30, 2021
faa8709
Documentation for GraphClassification included
May 30, 2021
072a35b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 30, 2021
be015df
Creation of from_pygdatasequence method in DataModule and GraphSequen…
Jun 8, 2021
1fb160b
Update graph_classification.py
Jun 7, 2021
bb3b941
Update datatype_graph.txt
Jun 7, 2021
3583d5c
Tests and docs for the from_pygdatasequence method
Jun 8, 2021
193c2bc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 8, 2021
f59a2fc
Graph requirements
Jun 8, 2021
71a15bc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 8, 2021
370ea24
Update CHANGELOG.md
Jun 8, 2021
e9d4e93
Update requirements with pytorch geometric libraries
Jun 8, 2021
a2b208e
Simplified, version with only the DataSource
Jul 12, 2021
7c3eaf4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2021
809b615
Minor tweaks
Jul 12, 2021
bf4a9b6
Merge branch 'master' of https://github.com/PabloAMC/lightning-flash
Jul 12, 2021
3089d94
Update the flash_example to reflect the new template
Jul 12, 2021
338c3ca
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2021
0f39778
Delete IMDB-BINARY_A.txt
Jul 12, 2021
b051f36
Delete IMDB-BINARY_graph_indicator.txt
Jul 12, 2021
4e23aff
Delete IMDB-BINARY_graph_labels.txt
Jul 12, 2021
1ad2ce3
Class method from_pygdatasequence from flash/core/data/data_module.py
Jul 12, 2021
4e9178c
Merge branch 'master' of https://github.com/PabloAMC/lightning-flash
Jul 12, 2021
94d4e60
Merge branch 'master' into master
ethanwharris Jul 14, 2021
73663bd
Update docs
ethanwharris Jul 14, 2021
287132f
Merge branch 'master' into master
ethanwharris Jul 14, 2021
2631bd4
fix imports.py
ethanwharris Jul 14, 2021
b4d3b41
remove unused imports
ethanwharris Jul 14, 2021
b19b5b8
clean init.py
ethanwharris Jul 14, 2021
7a4a914
updates
ethanwharris Jul 14, 2021
13aa012
Updates
ethanwharris Jul 14, 2021
e9cedb0
Updates
ethanwharris Jul 14, 2021
fe95a77
Updates
ethanwharris Jul 14, 2021
bbdad91
Updates
ethanwharris Jul 14, 2021
d5deb38
Update docs
ethanwharris Jul 14, 2021
f634e9f
Update docs
ethanwharris Jul 14, 2021
428f313
Update docs
ethanwharris Jul 14, 2021
435cc95
fix tests
ethanwharris Jul 14, 2021
b54e543
fix tests
ethanwharris Jul 14, 2021
4453818
Add API reference
ethanwharris Jul 14, 2021
b4877b1
Try fix
ethanwharris Jul 14, 2021
8fe2813
Try fix
ethanwharris Jul 14, 2021
113b6d0
Try fix
ethanwharris Jul 14, 2021
d8d26ab
Update flash/core/data/auto_dataset.py
ethanwharris Jul 14, 2021
7b2734f
Update docstring
ethanwharris Jul 14, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ jobs:
python-version: 3.8
requires: 'latest'
topic: ['serve']
- os: ubuntu-20.04
python-version: 3.8
requires: 'latest'
topic: ['graph']

# Timeout: https://stackoverflow.com/a/59076067/4521646
timeout-minutes: 35
Expand Down Expand Up @@ -109,6 +113,7 @@ jobs:
run: |
python --version
pip --version
pip install torch>=1.8
pip install '.[${{ join(matrix.topic,',') }}]' --pre --upgrade --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
pip install '.[test]' --pre --upgrade
pip list
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,4 @@ CameraRGB
CameraSeg
jigsaw_toxic_comments
flash_examples/serve/tabular_classification/data
flash_examples/data
5 changes: 2 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for nesting of `Task` objects ([#575](https://github.com/PyTorchLightning/lightning-flash/pull/575))

- Added a `GraphClassifier` task ([#73](https://github.com/PyTorchLightning/lightning-flash/pull/73))

### Changed

- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))

- Removed bolts pretrained weights for SSL from ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))

### Deprecated


### Fixed


Expand Down
33 changes: 33 additions & 0 deletions docs/source/api/graph.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
###########
flash.graph
###########

.. contents::
:depth: 1
:local:
:backlinks: top

.. currentmodule:: flash.graph

Classification
______________

.. autosummary::
:toctree: generated/
:nosignatures:
:template: classtemplate.rst

~classification.model.GraphClassifier
~classification.data.GraphClassificationData

classification.data.GraphClassificationPreprocess

flash.graph.data
________________

.. autosummary::
:toctree: generated/
:nosignatures:
:template: classtemplate.rst

~data.GraphDatasetDataSource
7 changes: 7 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ Lightning Flash
reference/summarization
reference/translation

.. toctree::
:maxdepth: 1
:caption: Graph

reference/graph_classification

.. toctree::
:maxdepth: 1
:caption: Integrations
Expand All @@ -73,6 +79,7 @@ Lightning Flash
api/tabular
api/text
api/video
api/graph

.. toctree::
:maxdepth: 1
Expand Down
33 changes: 33 additions & 0 deletions docs/source/reference/graph_classification.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
.. _graph_classification:

####################
Graph Classification
####################

********
The Task
********
This task consist on classifying graphs.
The task predicts which ‘class’ the graph belongs to.
A class is a label that indicates the kind of graph.
For example, a label may indicate whether one molecule interacts with another.

The :class:`~flash.graph.classification.model.GraphClassifier` and :class:`~flash.graph.classification.data.GraphClassificationData` classes internally rely on `pytorch-geometric <https://github.com/rusty1s/pytorch_geometric>`_.

------

*******
Example
*******

Let's look at the task of classifying graphs from the KKI data set from `TU Dortmund University <https://chrsmrrs.github.io/datasets>`_.

Once we've created the `TUDataset <https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/datasets/tu_dataset.html#TUDataset>`, we create the :class:`~flash.graph.classification.data.GraphClassificationData`.
We then create our :class:`~flash.graph.classification.model.GraphClassifier` and train on the KKI data.
Next, we use the trained :class:`~flash.graph.classification.model.GraphClassifier` for inference.
Finally, we save the model.
Here's the full example:

.. literalinclude:: ../../../flash_examples/graph_classification.py
:language: python
:lines: 14
5 changes: 5 additions & 0 deletions flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ def _compare_version(package: str, op, version) -> bool:
_PIL_AVAILABLE = _module_available("PIL")
_ASTEROID_AVAILABLE = _module_available("asteroid")
_SEGMENTATION_MODELS_AVAILABLE = _module_available("segmentation_models_pytorch")
_TORCH_SCATTER_AVAILABLE = _module_available("torch_scatter")
_TORCH_SPARSE_AVAILABLE = _module_available("torch_sparse")
_TORCH_GEOMETRIC_AVAILABLE = _module_available("torch_geometric")

if Version:
_TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0")
Expand All @@ -104,6 +107,7 @@ def _compare_version(package: str, op, version) -> bool:
_AUDIO_AVAILABLE = all([
_ASTEROID_AVAILABLE,
])
_GRAPH_AVAILABLE = _TORCH_SCATTER_AVAILABLE and _TORCH_SPARSE_AVAILABLE and _TORCH_GEOMETRIC_AVAILABLE

_EXTRAS_AVAILABLE = {
'image': _IMAGE_AVAILABLE,
Expand All @@ -112,6 +116,7 @@ def _compare_version(package: str, op, version) -> bool:
'video': _VIDEO_AVAILABLE,
'serve': _SERVE_AVAILABLE,
'audio': _AUDIO_AVAILABLE,
'graph': _GRAPH_AVAILABLE,
}


Expand Down
1 change: 1 addition & 0 deletions flash/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from flash.graph.classification import GraphClassificationData, GraphClassifier # noqa: F401
2 changes: 2 additions & 0 deletions flash/graph/classification/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from flash.graph.classification.data import GraphClassificationData # noqa: F401
from flash.graph.classification.model import GraphClassifier # noqa: F401
70 changes: 70 additions & 0 deletions flash/graph/classification/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, Optional

from flash.core.data.data_module import DataModule
from flash.core.data.data_source import DefaultDataSources
from flash.core.data.process import Preprocess
from flash.core.utilities.imports import _GRAPH_AVAILABLE, requires_extras
from flash.graph.data import GraphDatasetDataSource

if _GRAPH_AVAILABLE:
from torch_geometric.data.batch import Batch
from torch_geometric.transforms import NormalizeFeatures


class GraphClassificationPreprocess(Preprocess):

@requires_extras("graph")
def __init__(
self,
train_transform: Optional[Dict[str, Callable]] = None,
val_transform: Optional[Dict[str, Callable]] = None,
test_transform: Optional[Dict[str, Callable]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
):
super().__init__(
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
predict_transform=predict_transform,
data_sources={
DefaultDataSources.DATASET: GraphDatasetDataSource(),
},
default_data_source=DefaultDataSources.DATASET,
)

def get_state_dict(self) -> Dict[str, Any]:
return self.transforms

@classmethod
def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
return cls(**state_dict)

@staticmethod
def default_transforms() -> Optional[Dict[str, Callable]]:
return {"pre_tensor_transform": NormalizeFeatures(), "collate": Batch.from_data_list}
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved


class GraphClassificationData(DataModule):
"""Data module for graph classification tasks."""

preprocess_cls = GraphClassificationPreprocess

@property
def num_features(self):
n_cls_train = getattr(self.train_dataset, "num_features", None)
n_cls_val = getattr(self.val_dataset, "num_features", None)
n_cls_test = getattr(self.test_dataset, "num_features", None)
return n_cls_train or n_cls_val or n_cls_test
147 changes: 147 additions & 0 deletions flash/graph/classification/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, List, Mapping, Sequence, Type, Union

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import Linear

from flash.core.classification import ClassificationTask
from flash.core.utilities.imports import _TORCH_GEOMETRIC_AVAILABLE

if _TORCH_GEOMETRIC_AVAILABLE:
from torch_geometric.nn import BatchNorm, GCNConv, global_mean_pool, MessagePassing
else:
MessagePassing = None
GCNConv = None


class GraphBlock(nn.Module):

def __init__(self, nc_input, nc_output, conv_cls, act=nn.ReLU(), **conv_kwargs):
super().__init__()
self.conv = conv_cls(nc_input, nc_output, **conv_kwargs)
self.norm = BatchNorm(nc_output)
self.act = act

def forward(self, x, edge_index, edge_weight):
x = self.conv(x, edge_index, edge_weight=edge_weight)
x = self.norm(x)
return self.act(x)


class BaseGraphModel(nn.Module):

def __init__(
self,
num_features: int,
hidden_channels: List[int],
num_classes: int,
conv_cls: Type[MessagePassing],
act=nn.ReLU(),
**conv_kwargs: Any
):
super().__init__()

self.blocks = nn.ModuleList()
hidden_channels = [num_features] + hidden_channels

nc_output = num_features

for idx in range(len(hidden_channels) - 1):
nc_input = hidden_channels[idx]
nc_output = hidden_channels[idx + 1]
graph_block = GraphBlock(nc_input, nc_output, conv_cls, act, **conv_kwargs)
self.blocks.append(graph_block)

self.lin = Linear(nc_output, num_classes)

def forward(self, data):
x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr
# 1. Obtain node embeddings
for block in self.blocks:
x = block(x, edge_index, edge_weight)

# 2. Readout layer
x = global_mean_pool(x, data.batch) # [batch_size, hidden_channels]

# 3. Apply a final classifier
x = F.dropout(x, p=0.5, training=self.training)
x = self.lin(x)
return x


class GraphClassifier(ClassificationTask):
"""The ``GraphClassifier`` is a :class:`~flash.Task` for classifying graphs. For more details, see
:ref:`graph_classification`.

Args:
num_features: Number of columns in table (not including target column).
num_classes: Number of classes to classify.
hidden_channels: Hidden dimension sizes.
loss_fn: Loss function for training, defaults to cross entropy.
optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`.
metrics: Metrics to compute for training and evaluation.
learning_rate: Learning rate to use for training, defaults to `1e-3`
model: GraphNN used, defaults to BaseGraphModel.
conv_cls: kind of convolution used in model, defaults to GCNConv
"""

required_extras = "graph"

def __init__(
self,
num_features: int,
num_classes: int,
hidden_channels: Union[List[int], int] = 512,
loss_fn: Callable = F.cross_entropy,
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
metrics: Union[Callable, Mapping, Sequence, None] = None,
learning_rate: float = 1e-3,
model: torch.nn.Module = None,
conv_cls: Type[MessagePassing] = GCNConv,
**conv_kwargs
):

self.save_hyperparameters()

if isinstance(hidden_channels, int):
hidden_channels = [hidden_channels]

if not model:
model = BaseGraphModel(num_features, hidden_channels, num_classes, conv_cls, **conv_kwargs)

super().__init__(
model=model,
loss_fn=loss_fn,
optimizer=optimizer,
metrics=metrics,
learning_rate=learning_rate,
)

def training_step(self, batch: Any, batch_idx: int) -> Any:
batch = (batch, batch.y)
return super().training_step(batch, batch_idx)

def validation_step(self, batch: Any, batch_idx: int) -> Any:
batch = (batch, batch.y)
return super().validation_step(batch, batch_idx)

def test_step(self, batch: Any, batch_idx: int) -> Any:
batch = (batch, batch.y)
return super().test_step(batch, batch_idx)

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)
Loading