PyTorch Geometric integration (#73)
* Initial structure of GraphClassification model.py

* Improvements to model.py; still needs debugging

* BasicDataset Implemented

* Create __init__.py

* Implemented dataset and DataModule as for image processing

The Pipeline is still missing, and a split into raw and processed folders may be needed.

* Pipeline taken from images.

I'm unsure how to adapt it.

* Choice of model implemented (you can pass a model to GraphClassifier)

The class BasicGraphDataset in graphClassification/data.py is probably unneeded

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Initial readaptation of the structure

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Minimal structure for the data.py files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Minor corrections

* update

* i

* update

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Added auto_dataset.num_features

* Deleted manually included num_features so that it is extracted from GraphDatasetSource()

* Test for GraphClassification implemented

* Documentation for GraphClassification included

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Creation of from_pygdatasequence method in DataModule and GraphSequenceDataSource()

* Update graph_classification.py

* Update datatype_graph.txt

* Tests and docs for the from_pygdatasequence method

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Graph requirements

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update CHANGELOG.md

* Update requirements with pytorch geometric libraries

* Simplified version with only the DataSource

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Minor tweaks

* Update the flash_example to reflect the new template

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Delete IMDB-BINARY_A.txt

* Delete IMDB-BINARY_graph_indicator.txt

* Delete IMDB-BINARY_graph_labels.txt

* Class method from_pygdatasequence from flash/core/data/data_module.py

* Update docs

* fix imports.py

* remove unused imports

* clean init.py

* updates

* Updates

* Updates

* Updates

* Updates

* Update docs

* Update docs

* Update docs

* fix tests

* fix tests

* Add API reference

* Try fix

* Try fix

* Try fix

* Update flash/core/data/auto_dataset.py

* Update docstring

Co-authored-by: pablo <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: tchaton <[email protected]>
Co-authored-by: Ethan Harris <[email protected]>
Co-authored-by: Ethan Harris <[email protected]>
6 people authored Jul 14, 2021
1 parent 7237a9f commit a340464
Showing 21 changed files with 609 additions and 5 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/ci-testing.yml
@@ -53,6 +53,10 @@ jobs:
            python-version: 3.8
            requires: 'latest'
            topic: ['serve']
          - os: ubuntu-20.04
            python-version: 3.8
            requires: 'latest'
            topic: ['graph']

    # Timeout: https://stackoverflow.com/a/59076067/4521646
    timeout-minutes: 35
@@ -109,6 +113,7 @@ jobs:
        run: |
          python --version
          pip --version
          pip install "torch>=1.8"
          pip install '.[${{ join(matrix.topic,',') }}]' --pre --upgrade --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
          pip install '.[test]' --pre --upgrade
          pip list
1 change: 1 addition & 0 deletions .gitignore
@@ -159,3 +159,4 @@ CameraRGB
CameraSeg
jigsaw_toxic_comments
flash_examples/serve/tabular_classification/data
flash_examples/data
5 changes: 2 additions & 3 deletions CHANGELOG.md
@@ -18,15 +18,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for nesting of `Task` objects ([#575](https://github.com/PyTorchLightning/lightning-flash/pull/575))

- Added a `GraphClassifier` task ([#73](https://github.com/PyTorchLightning/lightning-flash/pull/73))

### Changed

- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))

- Removed bolts pretrained weights for SSL from ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))

### Deprecated


### Fixed


33 changes: 33 additions & 0 deletions docs/source/api/graph.rst
@@ -0,0 +1,33 @@
###########
flash.graph
###########

.. contents::
    :depth: 1
    :local:
    :backlinks: top

.. currentmodule:: flash.graph

Classification
______________

.. autosummary::
    :toctree: generated/
    :nosignatures:
    :template: classtemplate.rst

    ~classification.model.GraphClassifier
    ~classification.data.GraphClassificationData

    classification.data.GraphClassificationPreprocess

flash.graph.data
________________

.. autosummary::
    :toctree: generated/
    :nosignatures:
    :template: classtemplate.rst

    ~data.GraphDatasetDataSource
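For orientation, here is a minimal import sketch (assuming the "graph" extra is installed) showing where the classes listed in this API reference live; the paths match the new modules added in this commit:

from flash.graph import GraphClassificationData, GraphClassifier
from flash.graph.classification.data import GraphClassificationPreprocess
from flash.graph.data import GraphDatasetDataSource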
7 changes: 7 additions & 0 deletions docs/source/index.rst
@@ -55,6 +55,12 @@ Lightning Flash
    reference/summarization
    reference/translation

.. toctree::
    :maxdepth: 1
    :caption: Graph

    reference/graph_classification

.. toctree::
    :maxdepth: 1
    :caption: Integrations
@@ -73,6 +79,7 @@ Lightning Flash
    api/tabular
    api/text
    api/video
    api/graph

.. toctree::
    :maxdepth: 1
33 changes: 33 additions & 0 deletions docs/source/reference/graph_classification.rst
@@ -0,0 +1,33 @@
.. _graph_classification:

####################
Graph Classification
####################

********
The Task
********
This task consists of classifying whole graphs.
The model predicts which ‘class’ a graph belongs to.
A class is a label that indicates the kind of graph.
For example, a label may indicate whether one molecule interacts with another.

The :class:`~flash.graph.classification.model.GraphClassifier` and :class:`~flash.graph.classification.data.GraphClassificationData` classes internally rely on `pytorch-geometric <https://github.com/rusty1s/pytorch_geometric>`_.

------

*******
Example
*******

Let's look at the task of classifying graphs from the KKI data set from `TU Dortmund University <https://chrsmrrs.github.io/datasets>`_.

Once we've created the `TUDataset <https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/datasets/tu_dataset.html#TUDataset>`_, we create the :class:`~flash.graph.classification.data.GraphClassificationData`.
We then create our :class:`~flash.graph.classification.model.GraphClassifier` and train on the KKI data.
Next, we use the trained :class:`~flash.graph.classification.model.GraphClassifier` for inference.
Finally, we save the model.
Here's the full example:

.. literalinclude:: ../../../flash_examples/graph_classification.py
    :language: python
    :lines: 14
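Since the literalinclude above only points at the example file, here is a hedged sketch of what flash_examples/graph_classification.py roughly does; the `from_datasets` constructor, the `val_split` argument, and `model.predict` are assumptions about the surrounding Flash API, not something this diff confirms:

import flash
from flash.graph import GraphClassificationData, GraphClassifier
from torch_geometric.datasets import TUDataset

# 1. Create the KKI dataset from TU Dortmund University.
dataset = TUDataset(root="data", name="KKI")

# 2. Wrap it in a DataModule (assumed: DataModule.from_datasets with an automatic validation split).
datamodule = GraphClassificationData.from_datasets(train_dataset=dataset, val_split=0.1)

# 3. Build the task; num_features is read back from the dataset through the DataModule.
model = GraphClassifier(num_features=datamodule.num_features, num_classes=dataset.num_classes)

# 4. Train.
trainer = flash.Trainer(max_epochs=3)
trainer.fit(model, datamodule=datamodule)

# 5. Predict on a few graphs, then save the checkpoint.
predictions = model.predict(dataset[:3])
print(predictions)
trainer.save_checkpoint("graph_classification_model.pt")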
5 changes: 5 additions & 0 deletions flash/core/utilities/imports.py
@@ -85,6 +85,9 @@ def _compare_version(package: str, op, version) -> bool:
_PIL_AVAILABLE = _module_available("PIL")
_ASTEROID_AVAILABLE = _module_available("asteroid")
_SEGMENTATION_MODELS_AVAILABLE = _module_available("segmentation_models_pytorch")
_TORCH_SCATTER_AVAILABLE = _module_available("torch_scatter")
_TORCH_SPARSE_AVAILABLE = _module_available("torch_sparse")
_TORCH_GEOMETRIC_AVAILABLE = _module_available("torch_geometric")

if Version:
    _TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0")
@@ -104,6 +107,7 @@ def _compare_version(package: str, op, version) -> bool:
_AUDIO_AVAILABLE = all([
    _ASTEROID_AVAILABLE,
])
_GRAPH_AVAILABLE = _TORCH_SCATTER_AVAILABLE and _TORCH_SPARSE_AVAILABLE and _TORCH_GEOMETRIC_AVAILABLE

_EXTRAS_AVAILABLE = {
    'image': _IMAGE_AVAILABLE,
@@ -112,6 +116,7 @@ def _compare_version(package: str, op, version) -> bool:
    'video': _VIDEO_AVAILABLE,
    'serve': _SERVE_AVAILABLE,
    'audio': _AUDIO_AVAILABLE,
    'graph': _GRAPH_AVAILABLE,
}


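For context, a short sketch of how these availability flags are meant to be consumed elsewhere in the code base, mirroring the guarded imports in flash/graph/classification/data.py and model.py below; the fallback name here is illustrative:

from flash.core.utilities.imports import _GRAPH_AVAILABLE

if _GRAPH_AVAILABLE:
    from torch_geometric.nn import GCNConv
else:
    # The name must still exist so the module imports without the extra installed;
    # code paths that actually use it are gated by requires_extras("graph").
    GCNConv = None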
1 change: 1 addition & 0 deletions flash/graph/__init__.py
@@ -0,0 +1 @@
from flash.graph.classification import GraphClassificationData, GraphClassifier # noqa: F401
2 changes: 2 additions & 0 deletions flash/graph/classification/__init__.py
@@ -0,0 +1,2 @@
from flash.graph.classification.data import GraphClassificationData # noqa: F401
from flash.graph.classification.model import GraphClassifier # noqa: F401
70 changes: 70 additions & 0 deletions flash/graph/classification/data.py
@@ -0,0 +1,70 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, Optional

from flash.core.data.data_module import DataModule
from flash.core.data.data_source import DefaultDataSources
from flash.core.data.process import Preprocess
from flash.core.utilities.imports import _GRAPH_AVAILABLE, requires_extras
from flash.graph.data import GraphDatasetDataSource

if _GRAPH_AVAILABLE:
    from torch_geometric.data.batch import Batch
    from torch_geometric.transforms import NormalizeFeatures


class GraphClassificationPreprocess(Preprocess):

    @requires_extras("graph")
    def __init__(
        self,
        train_transform: Optional[Dict[str, Callable]] = None,
        val_transform: Optional[Dict[str, Callable]] = None,
        test_transform: Optional[Dict[str, Callable]] = None,
        predict_transform: Optional[Dict[str, Callable]] = None,
    ):
        super().__init__(
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            predict_transform=predict_transform,
            data_sources={
                DefaultDataSources.DATASET: GraphDatasetDataSource(),
            },
            default_data_source=DefaultDataSources.DATASET,
        )

    def get_state_dict(self) -> Dict[str, Any]:
        return self.transforms

    @classmethod
    def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
        return cls(**state_dict)

    @staticmethod
    def default_transforms() -> Optional[Dict[str, Callable]]:
        return {"pre_tensor_transform": NormalizeFeatures(), "collate": Batch.from_data_list}


class GraphClassificationData(DataModule):
    """Data module for graph classification tasks."""

    preprocess_cls = GraphClassificationPreprocess

    @property
    def num_features(self):
        n_cls_train = getattr(self.train_dataset, "num_features", None)
        n_cls_val = getattr(self.val_dataset, "num_features", None)
        n_cls_test = getattr(self.test_dataset, "num_features", None)
        return n_cls_train or n_cls_val or n_cls_test
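To make the default transforms above concrete, here is a standalone sketch (assuming only that torch_geometric is installed; the toy graphs are illustrative) of what NormalizeFeatures and Batch.from_data_list do before data reaches the model:

import torch
from torch_geometric.data import Batch, Data
from torch_geometric.transforms import NormalizeFeatures

graphs = [
    Data(x=torch.rand(4, 3), edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]])),
    Data(x=torch.rand(2, 3), edge_index=torch.tensor([[0], [1]])),
]

# "pre_tensor_transform": row-normalize each graph's node features.
normalize = NormalizeFeatures()
graphs = [normalize(g) for g in graphs]

# "collate": merge the per-sample graphs into one disjoint batched graph.
batch = Batch.from_data_list(graphs)
print(batch.num_graphs)  # 2
print(batch.x.shape)     # torch.Size([6, 3]) -- all nodes stacked
print(batch.batch)       # tensor([0, 0, 0, 0, 1, 1]) -- graph id of each node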
147 changes: 147 additions & 0 deletions flash/graph/classification/model.py
@@ -0,0 +1,147 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, List, Mapping, Sequence, Type, Union

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import Linear

from flash.core.classification import ClassificationTask
from flash.core.utilities.imports import _TORCH_GEOMETRIC_AVAILABLE

if _TORCH_GEOMETRIC_AVAILABLE:
    from torch_geometric.nn import BatchNorm, GCNConv, global_mean_pool, MessagePassing
else:
    MessagePassing = None
    GCNConv = None


class GraphBlock(nn.Module):

    def __init__(self, nc_input, nc_output, conv_cls, act=nn.ReLU(), **conv_kwargs):
        super().__init__()
        self.conv = conv_cls(nc_input, nc_output, **conv_kwargs)
        self.norm = BatchNorm(nc_output)
        self.act = act

    def forward(self, x, edge_index, edge_weight):
        x = self.conv(x, edge_index, edge_weight=edge_weight)
        x = self.norm(x)
        return self.act(x)


class BaseGraphModel(nn.Module):

    def __init__(
        self,
        num_features: int,
        hidden_channels: List[int],
        num_classes: int,
        conv_cls: Type[MessagePassing],
        act=nn.ReLU(),
        **conv_kwargs: Any
    ):
        super().__init__()

        self.blocks = nn.ModuleList()
        hidden_channels = [num_features] + hidden_channels

        nc_output = num_features

        for idx in range(len(hidden_channels) - 1):
            nc_input = hidden_channels[idx]
            nc_output = hidden_channels[idx + 1]
            graph_block = GraphBlock(nc_input, nc_output, conv_cls, act, **conv_kwargs)
            self.blocks.append(graph_block)

        self.lin = Linear(nc_output, num_classes)

    def forward(self, data):
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr
        # 1. Obtain node embeddings
        for block in self.blocks:
            x = block(x, edge_index, edge_weight)

        # 2. Readout layer
        x = global_mean_pool(x, data.batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        return x


class GraphClassifier(ClassificationTask):
    """The ``GraphClassifier`` is a :class:`~flash.Task` for classifying graphs. For more details, see
    :ref:`graph_classification`.

    Args:
        num_features: The number of features per node in the input graphs.
        num_classes: Number of classes to classify.
        hidden_channels: Hidden dimension sizes.
        loss_fn: Loss function for training, defaults to cross entropy.
        optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`.
        metrics: Metrics to compute for training and evaluation.
        learning_rate: Learning rate to use for training, defaults to `1e-3`.
        model: The graph neural network to use, defaults to ``BaseGraphModel``.
        conv_cls: The graph convolution class to use in the model, defaults to ``GCNConv``.
    """

    required_extras = "graph"

    def __init__(
        self,
        num_features: int,
        num_classes: int,
        hidden_channels: Union[List[int], int] = 512,
        loss_fn: Callable = F.cross_entropy,
        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
        metrics: Union[Callable, Mapping, Sequence, None] = None,
        learning_rate: float = 1e-3,
        model: torch.nn.Module = None,
        conv_cls: Type[MessagePassing] = GCNConv,
        **conv_kwargs
    ):

        self.save_hyperparameters()

        if isinstance(hidden_channels, int):
            hidden_channels = [hidden_channels]

        if not model:
            model = BaseGraphModel(num_features, hidden_channels, num_classes, conv_cls, **conv_kwargs)

        super().__init__(
            model=model,
            loss_fn=loss_fn,
            optimizer=optimizer,
            metrics=metrics,
            learning_rate=learning_rate,
        )

    def training_step(self, batch: Any, batch_idx: int) -> Any:
        batch = (batch, batch.y)
        return super().training_step(batch, batch_idx)

    def validation_step(self, batch: Any, batch_idx: int) -> Any:
        batch = (batch, batch.y)
        return super().validation_step(batch, batch_idx)

    def test_step(self, batch: Any, batch_idx: int) -> Any:
        batch = (batch, batch.y)
        return super().test_step(batch, batch_idx)

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)
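A minimal usage sketch (assuming the "graph" extra is installed; the toy graphs and hidden sizes are illustrative) of how the classifier consumes a PyG batch: the whole Batch object is the model input, and batch.y holds the per-graph targets that the *_step methods unpack:

import torch
from torch.nn import functional as F
from torch_geometric.data import Batch, Data

from flash.graph import GraphClassifier

model = GraphClassifier(num_features=3, num_classes=2, hidden_channels=[32, 32])

graphs = [
    Data(x=torch.rand(5, 3), edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]]), y=torch.tensor([0])),
    Data(x=torch.rand(4, 3), edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]]), y=torch.tensor([1])),
]
batch = Batch.from_data_list(graphs)

# Forward pass through BaseGraphModel: GCN blocks -> global mean pool -> linear head.
logits = model(batch)                    # shape: [num_graphs, num_classes]
loss = F.cross_entropy(logits, batch.y)  # what training_step computes via (batch, batch.y)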