Graph embedding and graph backbones (#592)
Co-authored-by: tchaton <[email protected]>
Co-authored-by: Ethan Harris <[email protected]>
4 people authored Nov 9, 2021
1 parent 8c9903d commit 00ef240
Showing 22 changed files with 360 additions and 117 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci-testing.yml
@@ -145,6 +145,7 @@ jobs:
pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install torch-scatter -f https://data.pyg.org/whl/torch-1.9.0+cpu.html
pip install torch-sparse -f https://data.pyg.org/whl/torch-1.9.0+cpu.html
pip install torch-cluster -f https://data.pyg.org/whl/torch-1.9.0+cpu.html
- name: Install dependencies
run: |
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added backbones for `GraphClassifier` ([#592](https://github.com/PyTorchLightning/lightning-flash/pull/592))

- Added `GraphEmbedder` task ([#592](https://github.com/PyTorchLightning/lightning-flash/pull/592))

### Changed

- Changed `Preprocess` to `InputTransform` ([#951](https://github.com/PyTorchLightning/lightning-flash/pull/951))
10 changes: 10 additions & 0 deletions docs/source/api/graph.rst
@@ -22,6 +22,16 @@ ______________

classification.data.GraphClassificationInputTransform

Embedding
_________

.. autosummary::
:toctree: generated/
:nosignatures:
:template: classtemplate.rst

~embedding.model.GraphEmbedder

flash.graph.data
________________

1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -85,6 +85,7 @@ Lightning Flash
:caption: Graph

reference/graph_classification
reference/graph_embedder

.. toctree::
:maxdepth: 1
28 changes: 28 additions & 0 deletions docs/source/reference/graph_embedder.rst
@@ -0,0 +1,28 @@
.. _graph_embedder:

##############
Graph Embedder
##############

********
The Task
********
This task consists of creating an embedding of a graph, that is, a vector of features that can be used for a downstream task.
The :class:`~flash.graph.embedding.model.GraphEmbedder` and :class:`~flash.graph.classification.data.GraphClassificationData` classes internally rely on `pytorch-geometric <https://github.com/rusty1s/pytorch_geometric>`_.

------

*******
Example
*******

Let's look at generating embeddings for graphs from the KKI data set, hosted by `TU Dortmund University <https://chrsmrrs.github.io/datasets>`_.

We start by creating the `TUDataset <https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/datasets/tu_dataset.html#TUDataset>`_.
Next, we load a trained :class:`~flash.graph.embedding.model.GraphEmbedder` (built from a previously trained :class:`~flash.graph.classification.model.GraphClassifier`).
Finally, we save the model.
Here's the full example:

.. literalinclude:: ../../../flash_examples/graph_embedder.py
:language: python
:lines: 14-
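
The example script itself is not loaded on this page. Here is a minimal sketch of what it plausibly contains, following the three steps above; the checkpoint URL is hypothetical, and the `GraphEmbedder(backbone)` constructor and the `from_datasets` arguments are assumptions rather than APIs confirmed by this diff:

import torch
from torch_geometric.datasets import TUDataset

import flash
from flash.graph import GraphClassificationData, GraphClassifier, GraphEmbedder

# 1. Create the TUDataset (KKI) and wrap it in a DataModule for prediction.
dataset = TUDataset(root="data", name="KKI")
datamodule = GraphClassificationData.from_datasets(predict_dataset=dataset, batch_size=4)

# 2. Wrap the backbone of a previously trained GraphClassifier in a GraphEmbedder
#    (hypothetical checkpoint URL -- substitute a real one).
classifier = GraphClassifier.load_from_checkpoint("https://example.com/graph_classification_model.pt")
embedder = GraphEmbedder(classifier.backbone)

# 3. Generate embeddings for the predict set, then save the model.
trainer = flash.Trainer(gpus=torch.cuda.device_count())
embeddings = trainer.predict(embedder, datamodule=datamodule)
print(embeddings)
trainer.save_checkpoint("graph_embedder.pt")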
5 changes: 4 additions & 1 deletion flash/core/utilities/imports.py
@@ -93,6 +93,7 @@ def _compare_version(package: str, op, version) -> bool:
_TORCH_SCATTER_AVAILABLE = _module_available("torch_scatter")
_TORCH_SPARSE_AVAILABLE = _module_available("torch_sparse")
_TORCH_GEOMETRIC_AVAILABLE = _module_available("torch_geometric")
_NETWORKX_AVAILABLE = _module_available("networkx")
_TORCHAUDIO_AVAILABLE = _module_available("torchaudio")
_SENTENCEPIECE_AVAILABLE = _module_available("sentencepiece")
_DATASETS_AVAILABLE = _module_available("datasets")
@@ -143,7 +144,9 @@ class Image:
_SERVE_AVAILABLE = _FASTAPI_AVAILABLE and _PYDANTIC_AVAILABLE and _CYTOOLZ_AVAILABLE and _UVICORN_AVAILABLE
_POINTCLOUD_AVAILABLE = _OPEN3D_AVAILABLE and _TORCHVISION_AVAILABLE
_AUDIO_AVAILABLE = all([_TORCHAUDIO_AVAILABLE, _LIBROSA_AVAILABLE, _TRANSFORMERS_AVAILABLE])
_GRAPH_AVAILABLE = _TORCH_SCATTER_AVAILABLE and _TORCH_SPARSE_AVAILABLE and _TORCH_GEOMETRIC_AVAILABLE
_GRAPH_AVAILABLE = (
_TORCH_SCATTER_AVAILABLE and _TORCH_SPARSE_AVAILABLE and _TORCH_GEOMETRIC_AVAILABLE and _NETWORKX_AVAILABLE
)

_EXTRAS_AVAILABLE = {
"image": _IMAGE_AVAILABLE,
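
These availability flags gate optional imports throughout the code base; with this change, the graph extra additionally requires networkx. A minimal sketch of the consuming pattern, mirroring the `if _GRAPH_AVAILABLE:` guards used in the files below:

from flash.core.utilities.imports import _GRAPH_AVAILABLE

if _GRAPH_AVAILABLE:
    # Import PyTorch Geometric only when all four graph dependencies are installed.
    from torch_geometric.nn import global_mean_pool
else:
    # Fall back to a placeholder so the module still imports without the extra.
    global_mean_pool = None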
1 change: 1 addition & 0 deletions flash/core/utilities/providers.py
@@ -47,3 +47,4 @@ def __str__(self):
_PYTORCHVIDEO = Provider("Facebook Research/PyTorchVideo", "https://github.com/facebookresearch/pytorchvideo")
_VISSL = Provider("Facebook Research/vissl", "https://github.com/facebookresearch/vissl")
_PYTORCH_FORECASTING = Provider("jdb78/PyTorch-Forecasting", "https://github.com/jdb78/pytorch-forecasting")
_PYTORCH_GEOMETRIC = Provider("PyG/PyTorch Geometric", "https://github.com/pyg-team/pytorch_geometric")
1 change: 1 addition & 0 deletions flash/graph/__init__.py
@@ -1 +1,2 @@
from flash.graph.classification import GraphClassificationData, GraphClassifier # noqa: F401
from flash.graph.embedding import GraphEmbedder # noqa: F401
41 changes: 41 additions & 0 deletions flash/graph/backbones.py
@@ -0,0 +1,41 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _GRAPH_AVAILABLE
from flash.core.utilities.providers import _PYTORCH_GEOMETRIC

if _GRAPH_AVAILABLE:
from torch_geometric.nn.models import GAT, GCN, GIN, GraphSAGE

MODELS = {"GCN": GCN, "GraphSAGE": GraphSAGE, "GAT": GAT, "GIN": GIN}
else:
MODELS = {}

GRAPH_BACKBONES = FlashRegistry("backbones")


def _load_graph_backbone(
model_name: str,
in_channels: int,
hidden_channels: int = 512,
num_layers: int = 4,
):
model = MODELS[model_name]
return model(in_channels, hidden_channels, num_layers)


for model_name in MODELS.keys():
GRAPH_BACKBONES(name=model_name, providers=_PYTORCH_GEOMETRIC)(partial(_load_graph_backbone, model_name))
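
Once registered, a backbone is looked up by name and instantiated through the partially applied loader. A minimal usage sketch (the `in_channels` value is illustrative):

from flash.graph.backbones import GRAPH_BACKBONES

# Resolve the "GCN" entry and build it; hidden_channels=512 and num_layers=4
# come from _load_graph_backbone's defaults.
backbone = GRAPH_BACKBONES.get("GCN")(in_channels=128)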
167 changes: 61 additions & 106 deletions flash/graph/classification/model.py
@@ -11,157 +11,90 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, List, Optional, Type, Union
from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch
from torch import nn, Tensor
from torch import nn
from torch.nn import functional as F
from torch.nn import Linear

from flash.core.classification import ClassificationTask
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _GRAPH_AVAILABLE
from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE
from flash.graph.backbones import GRAPH_BACKBONES

if _GRAPH_AVAILABLE:
from torch_geometric.nn import BatchNorm, GCNConv, global_mean_pool, MessagePassing
else:
MessagePassing = object
GCNConv = object


class GraphBlock(nn.Module):
"""Graph convolutional block.
Args:
nc_input: number of input channels
nc_output: number of output channels
conv_cls: graph convolutional class to use
act: activation function to use
**conv_kwargs: additional kwargs used for initialization of convolutional operator
"""

def __init__(
self,
nc_input: int,
nc_output: int,
conv_cls: nn.Module,
act: Union[Callable, nn.Module] = nn.ReLU(),
**conv_kwargs
):
super().__init__()
self.conv = conv_cls(nc_input, nc_output, **conv_kwargs)
self.norm = BatchNorm(nc_output)
self.act = act
from torch_geometric.nn import global_add_pool, global_max_pool, global_mean_pool

def forward(self, x: Tensor, edge_index: Tensor, edge_weight: Optional[Tensor] = None) -> Tensor:
x = self.conv(x, edge_index, edge_weight=edge_weight)
x = self.norm(x)
return self.act(x)


class BaseGraphModel(nn.Module):
"""Base convolutional graph model.
Args:
num_features: number of input features
hidden_channels: list of integers with the number of channels in all the hidden layers.
The length of the list determines the depth of the network.
num_classes: integer determining the number of classes
conv_cls: graph convolutional class to use as building blocks
act: activation function to use between layers
**conv_kwargs: additional kwargs used for initialization of convolutional operator
"""

def __init__(
self,
num_features: int,
hidden_channels: List[int],
num_classes: int,
conv_cls: Type[MessagePassing],
act: Union[Callable, nn.Module] = nn.ReLU(),
**conv_kwargs: Any
):
super().__init__()

self.blocks = nn.ModuleList()
hidden_channels = [num_features] + hidden_channels

nc_output = num_features

for idx in range(len(hidden_channels) - 1):
nc_input = hidden_channels[idx]
nc_output = hidden_channels[idx + 1]
graph_block = GraphBlock(nc_input, nc_output, conv_cls, act, **conv_kwargs)
self.blocks.append(graph_block)

self.lin = Linear(nc_output, num_classes)

def forward(self, data: Any) -> Tensor:
x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr
# 1. Obtain node embeddings
for block in self.blocks:
x = block(x, edge_index, edge_weight)

# 2. Readout layer
x = global_mean_pool(x, data.batch) # [batch_size, hidden_channels]

# 3. Apply a final classifier
x = F.dropout(x, p=0.5, training=self.training)
x = self.lin(x)
return x
POOLING_FUNCTIONS = {"mean": global_mean_pool, "add": global_add_pool, "max": global_max_pool}
else:
POOLING_FUNCTIONS = {}


class GraphClassifier(ClassificationTask):
"""The ``GraphClassifier`` is a :class:`~flash.Task` for classifying graphs. For more details, see
:ref:`graph_classification`.
Args:
num_features: Number of columns in table (not including target column).
num_classes: Number of classes to classify.
hidden_channels: Hidden dimension sizes.
learning_rate: Learning rate to use for training, defaults to `1e-3`
num_features (int): The number of features in the input.
num_classes (int): Number of classes to classify.
backbone: Name of the backbone to use.
backbone_kwargs: Dictionary dependent on the backbone, containing for example in_channels, out_channels,
hidden_channels or depth (number of layers).
pooling_fn: The global pooling operation to use (one of: "mean", "max", "add", or a callable).
head: The head to use.
loss_fn: Loss function for training, defaults to cross entropy.
learning_rate: Learning rate to use for training.
optimizer: Optimizer to use for training.
lr_scheduler: The LR scheduler to use during training.
metrics: Metrics to compute for training and evaluation.
model: GraphNN used, defaults to BaseGraphModel.
conv_cls: kind of convolution used in model, defaults to GCNConv
**conv_kwargs: additional kwargs used for initialization of convolutional operator
"""

required_extras = "graph"
backbones: FlashRegistry = GRAPH_BACKBONES

required_extras: str = "graph"

def __init__(
self,
num_features: int,
num_classes: int,
hidden_channels: Union[List[int], int] = 512,
model: torch.nn.Module = None,
backbone: Union[str, Tuple[nn.Module, int]] = "GCN",
backbone_kwargs: Optional[Dict] = None,
pooling_fn: Optional[Union[str, Callable]] = "mean",
head: Optional[Union[Callable, nn.Module]] = None,
loss_fn: LOSS_FN_TYPE = F.cross_entropy,
learning_rate: float = 1e-3,
optimizer: OPTIMIZER_TYPE = "Adam",
lr_scheduler: LR_SCHEDULER_TYPE = None,
metrics: METRICS_TYPE = None,
conv_cls: Type[MessagePassing] = GCNConv,
**conv_kwargs
):

self.save_hyperparameters()

if isinstance(hidden_channels, int):
hidden_channels = [hidden_channels]

if not model:
model = BaseGraphModel(num_features, hidden_channels, num_classes, conv_cls, **conv_kwargs)

super().__init__(
model=model,
loss_fn=loss_fn,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
metrics=metrics,
learning_rate=learning_rate,
)

self.save_hyperparameters()

if isinstance(backbone, tuple):
self.backbone, num_out_features = backbone
else:
self.backbone = self.backbones.get(backbone)(in_channels=num_features, **(backbone_kwargs or {}))
num_out_features = self.backbone.hidden_channels

self.pooling_fn = POOLING_FUNCTIONS[pooling_fn] if isinstance(pooling_fn, str) else pooling_fn

if head is not None:
self.head = head
else:
self.head = DefaultGraphHead(num_out_features, num_classes)

def training_step(self, batch: Any, batch_idx: int) -> Any:
batch = (batch, batch.y)
return super().training_step(batch, batch_idx)
@@ -176,3 +109,25 @@ def test_step(self, batch: Any, batch_idx: int) -> Any:

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)

def forward(self, data) -> torch.Tensor:
x = self.backbone(data.x, data.edge_index)
x = self.pooling_fn(x, data.batch)
return self.head(x)


class DefaultGraphHead(torch.nn.Module):
def __init__(self, hidden_channels, num_classes, dropout=0.5):
super().__init__()
self.lin1 = Linear(hidden_channels, hidden_channels)
self.lin2 = Linear(hidden_channels, num_classes)
self.dropout = dropout

def reset_parameters(self):
self.lin1.reset_parameters()
self.lin2.reset_parameters()

def forward(self, x):
x = F.relu(self.lin1(x))
x = F.dropout(x, p=self.dropout, training=self.training)
return self.lin2(x)
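
Putting the refactor together, a minimal usage sketch of the new constructor (the feature and class counts are illustrative):

from flash.graph import GraphClassifier

# A GraphSAGE backbone from the registry, global add pooling, and the default head.
model = GraphClassifier(
    num_features=128,
    num_classes=7,
    backbone="GraphSAGE",
    backbone_kwargs={"hidden_channels": 256, "num_layers": 2},
    pooling_fn="add",
)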
2 changes: 2 additions & 0 deletions flash/graph/embedding/__init__.py
@@ -0,0 +1,2 @@
from flash.graph.classification.data import GraphClassificationData # noqa: F401
from flash.graph.embedding.model import GraphEmbedder # noqa: F401
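
With these re-exports in place, both tasks are importable from the top-level graph package:

from flash.graph import GraphClassificationData, GraphClassifier, GraphEmbedder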