This repository has been archived by the owner on Oct 9, 2023. It is now read-only.
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Initial structure of GraphClassification model.py
* Improvement of model.py. Still need to debug etc
* BasicDataset Implemented
* Create __init__.py
* Implemented dataset and DataModule as for image processing. Lacking Pipeline, and it is possible that division in raw and processed folders might be needed.
* Pipeline taken from images. I'm unsure how to adapt
* Initial structure of GraphClassification model.py
* Improvement of model.py. Still need to debug etc
* BasicDataset Implemented
* Implemented dataset and DataModule as for image processing. Lacking Pipeline, and it is possible that division in raw and processed folders might be needed.
* Pipeline taken from images. I'm unsure how to adapt
* Choice of model implemented (you can pass a model to GraphClassifier). The class BasicGraphDataset in graphClassification/data.py is probably unneeded
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Initial readaptation of the structure
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Minimal structure of how to structure data.py files
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Minor corrections
* update
* i
* update
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Added auto_dataset.num_features
* Deleted manually included num_features so that it is extracted from GraphDatasetSource()
* Test for GraphClassification implemented
* Documentation for GraphClassification included
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Creation of from_pygdatasequence method in DataModule and GraphSequenceDataSource()
* Update graph_classification.py
* Update datatype_graph.txt
* Tests and docs for the from_pygdatasequence method
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Graph requirements
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Update CHANGELOG.md
* Update requirements with pytorch geometric libraries
* Simplified version with only the DataSource
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Minor tweaks
* Update the flash_example to reflect the new template
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Delete IMDB-BINARY_A.txt
* Delete IMDB-BINARY_graph_indicator.txt
* Delete IMDB-BINARY_graph_labels.txt
* Class method from_pygdatasequence from flash/core/data/data_module.py
* Update docs
* fix imports.py
* remove unused imports
* clean init.py
* updates
* Updates
* Updates
* Updates
* Updates
* Update docs
* Update docs
* Update docs
* fix tests
* fix tests
* Add API reference
* Try fix
* Try fix
* Try fix
* Update flash/core/data/auto_dataset.py
* Update docstring

Co-authored-by: pablo <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: tchaton <[email protected]>
Co-authored-by: Ethan Harris <[email protected]>
Co-authored-by: Ethan Harris <[email protected]>
1 parent: 7237a9f
Commit: a340464
Showing 21 changed files with 609 additions and 5 deletions.
@@ -0,0 +1,33 @@
###########
flash.graph
###########

.. contents::
    :depth: 1
    :local:
    :backlinks: top

.. currentmodule:: flash.graph

Classification
______________

.. autosummary::
    :toctree: generated/
    :nosignatures:
    :template: classtemplate.rst

    ~classification.model.GraphClassifier
    ~classification.data.GraphClassificationData

    classification.data.GraphClassificationPreprocess

flash.graph.data
________________

.. autosummary::
    :toctree: generated/
    :nosignatures:
    :template: classtemplate.rst

    ~data.GraphDatasetDataSource
@@ -0,0 +1,33 @@
.. _graph_classification:

####################
Graph Classification
####################

********
The Task
********
This task consists of classifying graphs.
The task predicts which 'class' the graph belongs to.
A class is a label that indicates the kind of graph.
For example, a label may indicate whether one molecule interacts with another.

The :class:`~flash.graph.classification.model.GraphClassifier` and :class:`~flash.graph.classification.data.GraphClassificationData` classes internally rely on `pytorch-geometric <https://github.com/rusty1s/pytorch_geometric>`_.

------

*******
Example
*******

Let's look at the task of classifying graphs from the KKI data set from `TU Dortmund University <https://chrsmrrs.github.io/datasets>`_.

Once we've created the `TUDataset <https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/datasets/tu_dataset.html#TUDataset>`_, we create the :class:`~flash.graph.classification.data.GraphClassificationData`.
We then create our :class:`~flash.graph.classification.model.GraphClassifier` and train on the KKI data.
Next, we use the trained :class:`~flash.graph.classification.model.GraphClassifier` for inference.
Finally, we save the model.
Here's the full example:

.. literalinclude:: ../../../flash_examples/graph_classification.py
    :language: python
    :lines: 14-
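
For reference, here is a minimal sketch of what the referenced flash_examples/graph_classification.py might look like. It is not part of this diff; it assumes that DataModule.from_datasets, the data module's num_classes property, flash.Trainer, Task.predict, and Trainer.save_checkpoint behave as in other Flash tasks.

import torch

import flash
from flash.graph import GraphClassificationData, GraphClassifier

from torch_geometric.datasets import TUDataset

# 1. Create the DataModule from the KKI dataset (downloaded into ./data).
dataset = TUDataset(root="data", name="KKI")
datamodule = GraphClassificationData.from_datasets(train_dataset=dataset, val_split=0.1)

# 2. Build the task; feature and class counts are read from the data module.
model = GraphClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes)

# 3. Train the model.
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)

# 4. Classify a few graphs.
predictions = model.predict(dataset[:3])
print(predictions)

# 5. Save the checkpoint.
trainer.save_checkpoint("graph_classification.pt")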
@@ -0,0 +1 @@
from flash.graph.classification import GraphClassificationData, GraphClassifier  # noqa: F401
@@ -0,0 +1,2 @@
from flash.graph.classification.data import GraphClassificationData  # noqa: F401
from flash.graph.classification.model import GraphClassifier  # noqa: F401
@@ -0,0 +1,70 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, Optional

from flash.core.data.data_module import DataModule
from flash.core.data.data_source import DefaultDataSources
from flash.core.data.process import Preprocess
from flash.core.utilities.imports import _GRAPH_AVAILABLE, requires_extras
from flash.graph.data import GraphDatasetDataSource

if _GRAPH_AVAILABLE:
    from torch_geometric.data.batch import Batch
    from torch_geometric.transforms import NormalizeFeatures


class GraphClassificationPreprocess(Preprocess):

    @requires_extras("graph")
    def __init__(
        self,
        train_transform: Optional[Dict[str, Callable]] = None,
        val_transform: Optional[Dict[str, Callable]] = None,
        test_transform: Optional[Dict[str, Callable]] = None,
        predict_transform: Optional[Dict[str, Callable]] = None,
    ):
        super().__init__(
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            predict_transform=predict_transform,
            data_sources={
                DefaultDataSources.DATASET: GraphDatasetDataSource(),
            },
            default_data_source=DefaultDataSources.DATASET,
        )

    def get_state_dict(self) -> Dict[str, Any]:
        return self.transforms

    @classmethod
    def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
        return cls(**state_dict)

    @staticmethod
    def default_transforms() -> Optional[Dict[str, Callable]]:
        return {"pre_tensor_transform": NormalizeFeatures(), "collate": Batch.from_data_list}


class GraphClassificationData(DataModule):
    """Data module for graph classification tasks."""

    preprocess_cls = GraphClassificationPreprocess

    @property
    def num_features(self):
        n_cls_train = getattr(self.train_dataset, "num_features", None)
        n_cls_val = getattr(self.val_dataset, "num_features", None)
        n_cls_test = getattr(self.test_dataset, "num_features", None)
        return n_cls_train or n_cls_val or n_cls_test
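
As a quick illustration of what the default preprocess does at collate time, here is a sketch (not part of this commit) that batches two hypothetical toy graphs with the collate function returned by default_transforms:

import torch
from torch_geometric.data import Data

from flash.graph.classification.data import GraphClassificationPreprocess

# Two toy graphs with 3 node features each (hypothetical shapes, for illustration only).
g1 = Data(x=torch.rand(4, 3), edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]]))
g2 = Data(x=torch.rand(2, 3), edge_index=torch.tensor([[0], [1]]))

transforms = GraphClassificationPreprocess.default_transforms()
batch = transforms["collate"]([g1, g2])  # Batch.from_data_list under the hood
print(batch.num_graphs, batch.x.shape)   # 2 graphs, 6 nodes x 3 features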
@@ -0,0 +1,147 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, List, Mapping, Sequence, Type, Union

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import Linear

from flash.core.classification import ClassificationTask
from flash.core.utilities.imports import _TORCH_GEOMETRIC_AVAILABLE

if _TORCH_GEOMETRIC_AVAILABLE:
    from torch_geometric.nn import BatchNorm, GCNConv, global_mean_pool, MessagePassing
else:
    MessagePassing = None
    GCNConv = None


class GraphBlock(nn.Module):

    def __init__(self, nc_input, nc_output, conv_cls, act=nn.ReLU(), **conv_kwargs):
        super().__init__()
        self.conv = conv_cls(nc_input, nc_output, **conv_kwargs)
        self.norm = BatchNorm(nc_output)
        self.act = act

    def forward(self, x, edge_index, edge_weight):
        x = self.conv(x, edge_index, edge_weight=edge_weight)
        x = self.norm(x)
        return self.act(x)


class BaseGraphModel(nn.Module):

    def __init__(
        self,
        num_features: int,
        hidden_channels: List[int],
        num_classes: int,
        conv_cls: Type[MessagePassing],
        act=nn.ReLU(),
        **conv_kwargs: Any
    ):
        super().__init__()

        self.blocks = nn.ModuleList()
        hidden_channels = [num_features] + hidden_channels

        nc_output = num_features

        for idx in range(len(hidden_channels) - 1):
            nc_input = hidden_channels[idx]
            nc_output = hidden_channels[idx + 1]
            graph_block = GraphBlock(nc_input, nc_output, conv_cls, act, **conv_kwargs)
            self.blocks.append(graph_block)

        self.lin = Linear(nc_output, num_classes)

    def forward(self, data):
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr
        # 1. Obtain node embeddings
        for block in self.blocks:
            x = block(x, edge_index, edge_weight)

        # 2. Readout layer
        x = global_mean_pool(x, data.batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        return x

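A quick shape check for BaseGraphModel, sketched under the assumption that torch_geometric is installed; the toy graph and the layer sizes below are hypothetical:

import torch
from torch_geometric.data import Batch, Data
from torch_geometric.nn import GCNConv

from flash.graph.classification.model import BaseGraphModel

# One toy graph: 5 nodes with 8 features each.
graph = Data(x=torch.rand(5, 8), edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]]))
batch = Batch.from_data_list([graph])

# Two GCN blocks (8 -> 16 -> 32) followed by mean pooling and a 3-way classifier head.
model = BaseGraphModel(num_features=8, hidden_channels=[16, 32], num_classes=3, conv_cls=GCNConv)
logits = model(batch)
print(logits.shape)  # torch.Size([1, 3])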
class GraphClassifier(ClassificationTask):
    """The ``GraphClassifier`` is a :class:`~flash.Task` for classifying graphs. For more details, see
    :ref:`graph_classification`.

    Args:
        num_features: The number of node features in the input graphs.
        num_classes: Number of classes to classify.
        hidden_channels: Hidden dimension sizes.
        loss_fn: Loss function for training, defaults to cross entropy.
        optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`.
        metrics: Metrics to compute for training and evaluation.
        learning_rate: Learning rate to use for training, defaults to `1e-3`.
        model: Graph neural network used, defaults to ``BaseGraphModel``.
        conv_cls: Kind of convolution used in the model, defaults to ``GCNConv``.
    """

    required_extras = "graph"

    def __init__(
        self,
        num_features: int,
        num_classes: int,
        hidden_channels: Union[List[int], int] = 512,
        loss_fn: Callable = F.cross_entropy,
        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
        metrics: Union[Callable, Mapping, Sequence, None] = None,
        learning_rate: float = 1e-3,
        model: torch.nn.Module = None,
        conv_cls: Type[MessagePassing] = GCNConv,
        **conv_kwargs
    ):

        self.save_hyperparameters()

        if isinstance(hidden_channels, int):
            hidden_channels = [hidden_channels]

        if not model:
            model = BaseGraphModel(num_features, hidden_channels, num_classes, conv_cls, **conv_kwargs)

        super().__init__(
            model=model,
            loss_fn=loss_fn,
            optimizer=optimizer,
            metrics=metrics,
            learning_rate=learning_rate,
        )

    def training_step(self, batch: Any, batch_idx: int) -> Any:
        # A PyG ``Batch`` carries its targets in ``batch.y``; repack it as the
        # (input, target) pair expected by the generic classification step.
        batch = (batch, batch.y)
        return super().training_step(batch, batch_idx)

    def validation_step(self, batch: Any, batch_idx: int) -> Any:
        batch = (batch, batch.y)
        return super().validation_step(batch, batch_idx)

    def test_step(self, batch: Any, batch_idx: int) -> Any:
        batch = (batch, batch.y)
        return super().test_step(batch, batch_idx)

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)
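
To illustrate the conv_cls and conv_kwargs hooks, here is a sketch (not part of this commit) that swaps the default GCNConv for torch_geometric's GraphConv; the sizes and the aggr value are hypothetical:

from torch_geometric.nn import GraphConv

from flash.graph.classification.model import GraphClassifier

# Same task, but built on GraphConv layers with two hidden sizes; extra keyword
# arguments (here `aggr`) are forwarded to every convolution layer.
model = GraphClassifier(
    num_features=8,
    num_classes=3,
    hidden_channels=[64, 64],
    conv_cls=GraphConv,
    aggr="mean",
)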