diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 606dd9b7d2..e9e4bfea9c 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -145,6 +145,7 @@ jobs: pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html pip install torch-scatter -f https://data.pyg.org/whl/torch-1.9.0+cpu.html pip install torch-sparse -f https://data.pyg.org/whl/torch-1.9.0+cpu.html + pip install torch-cluster -f https://data.pyg.org/whl/torch-1.9.0+cpu.html - name: Install dependencies run: | diff --git a/CHANGELOG.md b/CHANGELOG.md index b733661ba5..1abf240d9c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added backbones for `GraphClassifier` ([#592](https://github.com/PyTorchLightning/lightning-flash/pull/592)) + +- Added `GraphEmbedder` task ([#592](https://github.com/PyTorchLightning/lightning-flash/pull/592)) + ### Changed - Changed `Preprocess` to `InputTransform` ([#951](https://github.com/PyTorchLightning/lightning-flash/pull/951)) diff --git a/docs/source/api/graph.rst b/docs/source/api/graph.rst index 74becb27df..65a437cdf4 100644 --- a/docs/source/api/graph.rst +++ b/docs/source/api/graph.rst @@ -22,6 +22,16 @@ ______________ classification.data.GraphClassificationInputTransform +Embedding +_________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~embedding.model.GraphEmbedder + flash.graph.data ________________ diff --git a/docs/source/index.rst b/docs/source/index.rst index 5f9a5f9114..b23d4c0640 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -85,6 +85,7 @@ Lightning Flash :caption: Graph reference/graph_classification + reference/graph_embedder .. 
toctree:: :maxdepth: 1 diff --git a/docs/source/reference/graph_embedder.rst b/docs/source/reference/graph_embedder.rst new file mode 100644 index 0000000000..4f14dc2db6 --- /dev/null +++ b/docs/source/reference/graph_embedder.rst @@ -0,0 +1,28 @@ +.. _graph_embedder: + +############## +Graph Embedder +############## + +******** +The Task +******** +This task consists of creating an embedding of a graph. That is, a vector of features which can be used for a downstream task. +The :class:`~flash.graph.embedding.model.GraphEmbedder` and :class:`~flash.graph.classification.data.GraphClassificationData` classes internally rely on `pytorch-geometric <https://github.com/pyg-team/pytorch_geometric>`_. + +------ + +******* +Example +******* + +Let's look at generating embeddings of graphs from the KKI data set from `TU Dortmund University <https://chrsmrrs.github.io/datasets>`_. + +We start by creating the `TUDataset <https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html#torch_geometric.datasets.TUDataset>`_. +Next, we load a trained :class:`~flash.graph.embedding.model.GraphEmbedder` (from a previously trained :class:`~flash.graph.classification.model.GraphClassifier`). +Finally, we generate embeddings for the first three graphs. +Here's the full example: + +.. 
literalinclude:: ../../../flash_examples/graph_embedder.py + :language: python + :lines: 14- diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index cd50f8cb8c..41ade27623 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -93,6 +93,7 @@ def _compare_version(package: str, op, version) -> bool: _TORCH_SCATTER_AVAILABLE = _module_available("torch_scatter") _TORCH_SPARSE_AVAILABLE = _module_available("torch_sparse") _TORCH_GEOMETRIC_AVAILABLE = _module_available("torch_geometric") +_NETWORKX_AVAILABLE = _module_available("networkx") _TORCHAUDIO_AVAILABLE = _module_available("torchaudio") _SENTENCEPIECE_AVAILABLE = _module_available("sentencepiece") _DATASETS_AVAILABLE = _module_available("datasets") @@ -143,7 +144,9 @@ class Image: _SERVE_AVAILABLE = _FASTAPI_AVAILABLE and _PYDANTIC_AVAILABLE and _CYTOOLZ_AVAILABLE and _UVICORN_AVAILABLE _POINTCLOUD_AVAILABLE = _OPEN3D_AVAILABLE and _TORCHVISION_AVAILABLE _AUDIO_AVAILABLE = all([_TORCHAUDIO_AVAILABLE, _LIBROSA_AVAILABLE, _TRANSFORMERS_AVAILABLE]) -_GRAPH_AVAILABLE = _TORCH_SCATTER_AVAILABLE and _TORCH_SPARSE_AVAILABLE and _TORCH_GEOMETRIC_AVAILABLE +_GRAPH_AVAILABLE = ( + _TORCH_SCATTER_AVAILABLE and _TORCH_SPARSE_AVAILABLE and _TORCH_GEOMETRIC_AVAILABLE and _NETWORKX_AVAILABLE +) _EXTRAS_AVAILABLE = { "image": _IMAGE_AVAILABLE, diff --git a/flash/core/utilities/providers.py b/flash/core/utilities/providers.py index d536154d73..4c2af721a9 100644 --- a/flash/core/utilities/providers.py +++ b/flash/core/utilities/providers.py @@ -47,3 +47,4 @@ def __str__(self): _PYTORCHVIDEO = Provider("Facebook Research/PyTorchVideo", "https://github.com/facebookresearch/pytorchvideo") _VISSL = Provider("Facebook Research/vissl", "https://github.com/facebookresearch/vissl") _PYTORCH_FORECASTING = Provider("jdb78/PyTorch-Forecasting", "https://github.com/jdb78/pytorch-forecasting") +_PYTORCH_GEOMETRIC = Provider("PyG/PyTorch Geometric", 
"https://github.com/pyg-team/pytorch_geometric") diff --git a/flash/graph/__init__.py b/flash/graph/__init__.py index cb30102379..64cfec6a12 100644 --- a/flash/graph/__init__.py +++ b/flash/graph/__init__.py @@ -1 +1,2 @@ from flash.graph.classification import GraphClassificationData, GraphClassifier # noqa: F401 +from flash.graph.embedding import GraphEmbedder # noqa: F401 diff --git a/flash/graph/backbones.py b/flash/graph/backbones.py new file mode 100644 index 0000000000..d09262c569 --- /dev/null +++ b/flash/graph/backbones.py @@ -0,0 +1,41 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from functools import partial + +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _GRAPH_AVAILABLE +from flash.core.utilities.providers import _PYTORCH_GEOMETRIC + +if _GRAPH_AVAILABLE: + from torch_geometric.nn.models import GAT, GCN, GIN, GraphSAGE + + MODELS = {"GCN": GCN, "GraphSAGE": GraphSAGE, "GAT": GAT, "GIN": GIN} +else: + MODELS = {} + +GRAPH_BACKBONES = FlashRegistry("backbones") + + +def _load_graph_backbone( + model_name: str, + in_channels: int, + hidden_channels: int = 512, + num_layers: int = 4, +): + model = MODELS[model_name] + return model(in_channels, hidden_channels, num_layers) + + +for model_name in MODELS.keys(): + GRAPH_BACKBONES(name=model_name, providers=_PYTORCH_GEOMETRIC)(partial(_load_graph_backbone, model_name)) diff --git a/flash/graph/classification/model.py b/flash/graph/classification/model.py index d1069f5121..254fd5366a 100644 --- a/flash/graph/classification/model.py +++ b/flash/graph/classification/model.py @@ -11,104 +11,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, List, Optional, Type, Union +from typing import Any, Callable, Dict, Optional, Tuple, Union import torch -from torch import nn, Tensor +from torch import nn from torch.nn import functional as F from torch.nn import Linear from flash.core.classification import ClassificationTask +from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _GRAPH_AVAILABLE from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE +from flash.graph.backbones import GRAPH_BACKBONES if _GRAPH_AVAILABLE: - from torch_geometric.nn import BatchNorm, GCNConv, global_mean_pool, MessagePassing -else: - MessagePassing = object - GCNConv = object - - -class GraphBlock(nn.Module): - """Graph convolutional block. 
- - Args: - nc_input: number of input channels - nc_output: number of output channels - conv_cls: graph convolutional class to use - act: activation function to use - **conv_kwargs: additional kwargs used for initialization of convolutional operator - """ - - def __init__( - self, - nc_input: int, - nc_output: int, - conv_cls: nn.Module, - act: Union[Callable, nn.Module] = nn.ReLU(), - **conv_kwargs - ): - super().__init__() - self.conv = conv_cls(nc_input, nc_output, **conv_kwargs) - self.norm = BatchNorm(nc_output) - self.act = act + from torch_geometric.nn import global_add_pool, global_max_pool, global_mean_pool - def forward(self, x: Tensor, edge_index: Tensor, edge_weight: Optional[Tensor] = None) -> Tensor: - x = self.conv(x, edge_index, edge_weight=edge_weight) - x = self.norm(x) - return self.act(x) - - -class BaseGraphModel(nn.Module): - """Base convolutional graph model. - - Args: - num_features: number of input features - hidden_channels: list of integers with the number of channels in all the hidden layers. - The length of the list determines the depth of the network. 
- num_classes: integer determining the number of classes - conv_cls: graph convolutional class to use as building blocks - act: activation function to use between layers - **conv_kwargs: additional kwargs used for initialization of convolutional operator - """ - - def __init__( - self, - num_features: int, - hidden_channels: List[int], - num_classes: int, - conv_cls: Type[MessagePassing], - act: Union[Callable, nn.Module] = nn.ReLU(), - **conv_kwargs: Any - ): - super().__init__() - - self.blocks = nn.ModuleList() - hidden_channels = [num_features] + hidden_channels - - nc_output = num_features - - for idx in range(len(hidden_channels) - 1): - nc_input = hidden_channels[idx] - nc_output = hidden_channels[idx + 1] - graph_block = GraphBlock(nc_input, nc_output, conv_cls, act, **conv_kwargs) - self.blocks.append(graph_block) - - self.lin = Linear(nc_output, num_classes) - - def forward(self, data: Any) -> Tensor: - x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr - # 1. Obtain node embeddings - for block in self.blocks: - x = block(x, edge_index, edge_weight) - - # 2. Readout layer - x = global_mean_pool(x, data.batch) # [batch_size, hidden_channels] - - # 3. Apply a final classifier - x = F.dropout(x, p=0.5, training=self.training) - x = self.lin(x) - return x + POOLING_FUNCTIONS = {"mean": global_mean_pool, "add": global_add_pool, "max": global_max_pool} +else: + POOLING_FUNCTIONS = {} class GraphClassifier(ClassificationTask): @@ -116,45 +37,42 @@ class GraphClassifier(ClassificationTask): :ref:`graph_classification`. Args: - num_features: Number of columns in table (not including target column). - num_classes: Number of classes to classify. - hidden_channels: Hidden dimension sizes. - learning_rate: Learning rate to use for training, defaults to `1e-3` + num_features (int): The number of features in the input. + num_classes (int): Number of classes to classify. + backbone: Name of the backbone to use. 
+ backbone_kwargs: Extra keyword arguments forwarded to the backbone (for example ``hidden_channels`` or + ``num_layers``). ``None`` or an empty dictionary uses the backbone defaults. + pooling_fn: The global pooling operation to use (one of: "mean", "max", "add" or a callable). + head: The head to use. + loss_fn: Loss function for training, defaults to cross entropy. + learning_rate: Learning rate to use for training. optimizer: Optimizer to use for training. lr_scheduler: The LR scheduler to use during training. metrics: Metrics to compute for training and evaluation. - model: GraphNN used, defaults to BaseGraphModel. - conv_cls: kind of convolution used in model, defaults to GCNConv - **conv_kwargs: additional kwargs used for initialization of convolutional operator """ - required_extras = "graph" + backbones: FlashRegistry = GRAPH_BACKBONES + + required_extras: str = "graph" def __init__( self, num_features: int, num_classes: int, - hidden_channels: Union[List[int], int] = 512, - model: torch.nn.Module = None, + backbone: Union[str, Tuple[nn.Module, int]] = "GCN", + backbone_kwargs: Optional[Dict] = None, + pooling_fn: Optional[Union[str, Callable]] = "mean", + head: Optional[Union[Callable, nn.Module]] = None, loss_fn: LOSS_FN_TYPE = F.cross_entropy, learning_rate: float = 1e-3, optimizer: OPTIMIZER_TYPE = "Adam", lr_scheduler: LR_SCHEDULER_TYPE = None, metrics: METRICS_TYPE = None, - conv_cls: Type[MessagePassing] = GCNConv, - **conv_kwargs ): self.save_hyperparameters() - if isinstance(hidden_channels, int): - hidden_channels = [hidden_channels] - - if not model: - model = BaseGraphModel(num_features, hidden_channels, num_classes, conv_cls, **conv_kwargs) - super().__init__( - model=model, loss_fn=loss_fn, optimizer=optimizer, lr_scheduler=lr_scheduler, @@ -162,6 +80,21 @@ def __init__( learning_rate=learning_rate, ) + self.save_hyperparameters() + + if isinstance(backbone, tuple): + self.backbone, num_out_features = backbone + else: + self.backbone = 
self.backbones.get(backbone)(in_channels=num_features, **(backbone_kwargs or {})) + num_out_features = self.backbone.hidden_channels + + self.pooling_fn = POOLING_FUNCTIONS[pooling_fn] if isinstance(pooling_fn, str) else pooling_fn + + if head is not None: + self.head = head + else: + self.head = DefaultGraphHead(num_out_features, num_classes) + def training_step(self, batch: Any, batch_idx: int) -> Any: batch = (batch, batch.y) return super().training_step(batch, batch_idx) @@ -176,3 +109,25 @@ def test_step(self, batch: Any, batch_idx: int) -> Any: def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx) + + def forward(self, data) -> torch.Tensor: + x = self.backbone(data.x, data.edge_index) + x = self.pooling_fn(x, data.batch) + return self.head(x) + + +class DefaultGraphHead(torch.nn.Module): + def __init__(self, hidden_channels, num_classes, dropout=0.5): + super().__init__() + self.lin1 = Linear(hidden_channels, hidden_channels) + self.lin2 = Linear(hidden_channels, num_classes) + self.dropout = dropout + + def reset_parameters(self): + self.lin1.reset_parameters() + self.lin2.reset_parameters() + + def forward(self, x): + x = F.relu(self.lin1(x)) + x = F.dropout(x, p=self.dropout, training=self.training) + return self.lin2(x) diff --git a/flash/graph/embedding/__init__.py b/flash/graph/embedding/__init__.py new file mode 100644 index 0000000000..9cd978422c --- /dev/null +++ b/flash/graph/embedding/__init__.py @@ -0,0 +1,2 @@ +from flash.graph.classification.data import GraphClassificationData # noqa: F401 +from flash.graph.embedding.model import GraphEmbedder # noqa: F401 diff --git a/flash/graph/embedding/model.py b/flash/graph/embedding/model.py new file mode 100644 index 0000000000..8624fcb0d0 --- /dev/null +++ b/flash/graph/embedding/model.py @@ -0,0 +1,79 @@ +# Copyright The PyTorch Lightning team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Dict, IO, Optional, Union + +import torch +from torch import nn + +from flash.core.model import Task +from flash.graph.classification.data import GraphClassificationInputTransform +from flash.graph.classification.model import GraphClassifier, POOLING_FUNCTIONS + + +class GraphEmbedder(Task): + """The ``GraphEmbedder`` is a :class:`~flash.Task` for obtaining feature vectors (embeddings) from graphs. For + more details, see :ref:`graph_embedder`. + + Args: + backbone: A model to use to extract graph features. + pooling_fn: The global pooling operation to use (one of: "mean", "max", "add" or a callable). + """ + + required_extras: str = "graph" + + def __init__(self, backbone: nn.Module, pooling_fn: Optional[Union[str, Callable]] = "mean"): + super().__init__(model=None, input_transform=GraphClassificationInputTransform()) + + # Never save the backbone; save pooling_fn only when it is given as a string + self.save_hyperparameters(ignore=["backbone"] if isinstance(pooling_fn, str) else ["backbone", "pooling_fn"]) + + self.backbone = backbone + + self.pooling_fn = POOLING_FUNCTIONS[pooling_fn] if isinstance(pooling_fn, str) else pooling_fn + + def forward(self, data) -> torch.Tensor: + x = self.backbone(data.x, data.edge_index) + x = self.pooling_fn(x, data.batch) + return x + + def training_step(self, batch: Any, batch_idx: int) -> Any: + raise NotImplementedError("Training a `GraphEmbedder` is not supported. 
Use a `GraphClassifier` instead.") + + def validation_step(self, batch: Any, batch_idx: int) -> Any: + raise NotImplementedError("Validating a `GraphEmbedder` is not supported. Use a `GraphClassifier` instead.") + + def test_step(self, batch: Any, batch_idx: int) -> Any: + raise NotImplementedError("Testing a `GraphEmbedder` is not supported. Use a `GraphClassifier` instead.") + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx) + + @classmethod + def load_from_checkpoint( + cls, + checkpoint_path: Union[str, IO], + map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None, + hparams_file: Optional[str] = None, + strict: bool = True, + **kwargs, + ) -> "GraphEmbedder": + classifier = GraphClassifier.load_from_checkpoint( + checkpoint_path, + map_location=map_location, + hparams_file=hparams_file, + strict=strict, + **kwargs, + ) + + return cls(classifier.backbone, classifier.pooling_fn) diff --git a/flash_examples/graph_classification.py b/flash_examples/graph_classification.py index 4519f70c33..a975bcbe4b 100644 --- a/flash_examples/graph_classification.py +++ b/flash_examples/graph_classification.py @@ -28,9 +28,11 @@ train_dataset=dataset, val_split=0.1, ) - # 2. Build the task -model = GraphClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes) +backbone_kwargs = {"hidden_channels": 512, "num_layers": 4} +model = GraphClassifier( + num_features=datamodule.num_features, num_classes=datamodule.num_classes, backbone_kwargs=backbone_kwargs +) # 3. Create the trainer and fit the model trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count()) @@ -41,4 +43,4 @@ print(predictions) # 5. Save the model! 
-trainer.save_checkpoint("graph_classification.pt") +trainer.save_checkpoint("graph_classification_model.pt") diff --git a/flash_examples/graph_embedder.py b/flash_examples/graph_embedder.py new file mode 100644 index 0000000000..7646b0f5c8 --- /dev/null +++ b/flash_examples/graph_embedder.py @@ -0,0 +1,29 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from flash.core.utilities.imports import example_requires +from flash.graph import GraphEmbedder + +example_requires("graph") + +from torch_geometric.datasets import TUDataset # noqa: E402 + +# 1. Create the dataset +dataset = TUDataset(root="data", name="KKI") + +# 2. Load a GraphEmbedder from a previously trained GraphClassifier checkpoint +model = GraphEmbedder.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.6.0/graph_classification_model.pt") + +# 3. 
Generate embeddings for the first 3 graphs +predictions = model.predict(dataset[:3]) +print(predictions) diff --git a/requirements.txt b/requirements.txt index 5e391bf5be..63cdf3da01 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ packaging numpy torch>=1.7.1 -torchmetrics>=0.5.1 +torchmetrics>=0.5.0,!=0.5.1 pytorch-lightning>=1.4.0 pyDeprecate pandas<1.3.0 diff --git a/requirements/datatype_graph.txt b/requirements/datatype_graph.txt index 9109e2167f..79fa3fc6c0 100644 --- a/requirements/datatype_graph.txt +++ b/requirements/datatype_graph.txt @@ -1,3 +1,5 @@ torch-scatter torch-sparse -torch-geometric +torch-geometric>=2.0.0 +torch-cluster +networkx diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 722f0f5d32..3f91113201 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -38,10 +38,11 @@ from flash.core.data.io.input_transform import DefaultInputTransform from flash.core.data.io.output_transform import OutputTransform from flash.core.utilities.imports import _TORCH_OPTIMIZER_AVAILABLE, _TRANSFORMERS_AVAILABLE, Image +from flash.graph import GraphClassifier, GraphEmbedder from flash.image import ImageClassificationData, ImageClassifier, SemanticSegmentation from flash.tabular import TabularClassifier from flash.text import SummarizationTask, TextClassifier, TranslationTask -from tests.helpers.utils import _AUDIO_TESTING, _IMAGE_TESTING, _TABULAR_TESTING, _TEXT_TESTING +from tests.helpers.utils import _AUDIO_TESTING, _GRAPH_TESTING, _IMAGE_TESTING, _TABULAR_TESTING, _TEXT_TESTING # ======== Mock functions ======== @@ -305,6 +306,22 @@ def test_task_datapipeline_save(tmpdir): reason="text packages aren't installed", ), ), + pytest.param( + GraphClassifier, + "0.6.0/graph_classification_model.pt", + marks=pytest.mark.skipif( + not _GRAPH_TESTING, + reason="graph packages aren't installed", + ), + ), + pytest.param( + GraphEmbedder, + "0.6.0/graph_classification_model.pt", + marks=pytest.mark.skipif( + not 
_GRAPH_TESTING, + reason="graph packages aren't installed", + ), + ), ], ) def test_model_download(tmpdir, cls, filename): diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py index eeeb725ee9..1c351c84ac 100644 --- a/tests/examples/test_scripts.py +++ b/tests/examples/test_scripts.py @@ -108,6 +108,10 @@ "graph_classification.py", marks=pytest.mark.skipif(not _GRAPH_TESTING, reason="graph libraries aren't installed"), ), + pytest.param( + "graph_embedder.py", + marks=pytest.mark.skipif(not _GRAPH_TESTING, reason="graph libraries aren't installed"), + ), ], ) def test_example(tmpdir, file): diff --git a/tests/graph/classification/test_data.py b/tests/graph/classification/test_data.py index 8cce12dd04..a588390d17 100644 --- a/tests/graph/classification/test_data.py +++ b/tests/graph/classification/test_data.py @@ -14,11 +14,11 @@ import pytest from flash.core.data.transforms import merge_transforms -from flash.core.utilities.imports import _TORCH_GEOMETRIC_AVAILABLE +from flash.core.utilities.imports import _GRAPH_AVAILABLE from flash.graph.classification.data import GraphClassificationData, GraphClassificationInputTransform from tests.helpers.utils import _GRAPH_TESTING -if _TORCH_GEOMETRIC_AVAILABLE: +if _GRAPH_AVAILABLE: from torch_geometric.datasets import TUDataset from torch_geometric.transforms import OneHotDegree diff --git a/tests/graph/classification/test_model.py b/tests/graph/classification/test_model.py index 8392e8010d..271e7ecab5 100644 --- a/tests/graph/classification/test_model.py +++ b/tests/graph/classification/test_model.py @@ -19,12 +19,12 @@ from flash import Trainer from flash.__main__ import main from flash.core.data.data_pipeline import DataPipeline -from flash.core.utilities.imports import _TORCH_GEOMETRIC_AVAILABLE +from flash.core.utilities.imports import _GRAPH_AVAILABLE from flash.graph.classification import GraphClassifier from flash.graph.classification.data import GraphClassificationInputTransform from 
tests.helpers.utils import _GRAPH_TESTING -if _TORCH_GEOMETRIC_AVAILABLE: +if _GRAPH_AVAILABLE: from torch_geometric import datasets diff --git a/tests/graph/embedding/__init__.py b/tests/graph/embedding/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/graph/embedding/test_model.py b/tests/graph/embedding/test_model.py new file mode 100644 index 0000000000..f7c15b1095 --- /dev/null +++ b/tests/graph/embedding/test_model.py @@ -0,0 +1,63 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest +import torch + +from flash import Trainer +from flash.core.data.data_pipeline import DataPipeline +from flash.core.utilities.imports import _GRAPH_AVAILABLE +from flash.graph.classification.data import GraphClassificationInputTransform +from flash.graph.classification.model import GraphClassifier +from flash.graph.embedding.model import GraphEmbedder +from tests.helpers.utils import _GRAPH_TESTING + +if _GRAPH_AVAILABLE: + from torch_geometric import datasets + + +@pytest.mark.skipif(not _GRAPH_TESTING, reason="pytorch geometric isn't installed") +def test_smoke(): + """A simple test that the class can be instantiated from a GraphClassifier backbone.""" + model = GraphEmbedder(GraphClassifier(num_features=1, num_classes=1).backbone) + assert model is not None + + +@pytest.mark.skipif(not _GRAPH_TESTING, reason="pytorch geometric isn't installed") +def test_not_trainable(tmpdir): + """Tests that the model gives an error when training, validating, or testing.""" + tudataset = datasets.TUDataset(root=tmpdir, name="KKI") + model = GraphEmbedder(GraphClassifier(num_features=1, num_classes=1).backbone) + model.data_pipeline = DataPipeline(input_transform=GraphClassificationInputTransform()) + dl = torch.utils.data.DataLoader(tudataset, batch_size=4) + trainer = Trainer(default_root_dir=tmpdir, num_sanity_val_steps=0) + with pytest.raises(NotImplementedError, match="Training a `GraphEmbedder` is not supported."): + trainer.fit(model, dl) + + with pytest.raises(NotImplementedError, match="Validating a `GraphEmbedder` is not supported."): + trainer.validate(model, dl) + + with pytest.raises(NotImplementedError, match="Testing a `GraphEmbedder` is not supported."): + trainer.test(model, dl) + + +@pytest.mark.skipif(not _GRAPH_TESTING, reason="pytorch geometric isn't installed") +def test_predict_dataset(tmpdir): + """Tests that we can generate embeddings from a pytorch geometric dataset.""" + tudataset = datasets.TUDataset(root=tmpdir, name="KKI") + model = 
GraphEmbedder( + GraphClassifier(num_features=tudataset.num_features, num_classes=tudataset.num_classes).backbone + ) + data_pipe = DataPipeline(input_transform=GraphClassificationInputTransform()) + out = model.predict(tudataset, data_source="datasets", data_pipeline=data_pipe) + assert isinstance(out[0], torch.Tensor)