This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit 571af04
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed May 14, 2021
1 parent c9b675f commit 571af04
Showing 3 changed files with 28 additions and 23 deletions.
2 changes: 1 addition & 1 deletion flash/graph/GraphClassification/__init__.py
@@ -1,2 +1,2 @@
from flash.graph.classification.data import GraphClassificationData
-from flash.graph.classification.model import GraphClassifier
\ No newline at end of file
+from flash.graph.classification.model import GraphClassifier
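For orientation, this __init__.py re-exports the package's two public names. A minimal usage sketch (the import path mirrors the GraphClassification directory name in this diff and is illustrative, not confirmed API):

# Sketch: importing the names re-exported above (hypothetical usage, not part of the commit).
from flash.graph.GraphClassification import GraphClassificationData, GraphClassifier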
24 changes: 14 additions & 10 deletions flash/graph/GraphClassification/data.py
@@ -15,28 +15,27 @@
import pathlib
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

+import networkx as nx
import torch
from pytorch_lightning.utilities.exceptions import MisconfigurationException

-import networkx as nx
-from torch_geometric.data import Dataset, DataLoader
+from torch_geometric.data import DataLoader, Dataset

from flash.core.classification import ClassificationDataPipeline
from flash.core.data.datamodule import DataModule
from flash.core.data.utils import _contains_any_tensor

'''
The structure we follow is DataSet -> DataLoader -> DataModule -> DataPipeline
'''

-class BasicGraphDataset(Dataset):
+
+class BasicGraphDataset(Dataset):
'''
#todo: Probably unnecessary having the following class.
'''

-    def __init__(self, root = None, processed_dir = 'processed', raw_dir = 'raw', transform=None, pre_transform=None, pre_filter=None):
+    def __init__(
+        self, root=None, processed_dir='processed', raw_dir='raw', transform=None, pre_transform=None, pre_filter=None
+    ):

super(BasicGraphDataset, self).__init__(root, transform, pre_transform, pre_filter)

@@ -66,9 +65,11 @@ def len(self):

def get(self, idx):
data = torch.load(os.path.join(self.processed_dir, 'data_{}.pt'.format(idx)))
-        #TODO: Is data.pt the best way/file type to load the data?
+        #TODO: Is data.pt the best way/file type to load the data?
#TODO: Interface with networkx would probably go here with some option to say how to load it
return data


class FilepathDataset(torch.utils.data.Dataset):
"""Dataset that takes in filepaths and labels. Taken from image"""

@@ -117,6 +118,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Optional[int]]:
label = self.label_to_class_mapping[filename]
return graph, label


class FlashDatasetFolder(torch.utils.data.Dataset):
"""A generic data loader where the samples are arranged in this way: ::
@@ -155,7 +157,8 @@ def __init__(
self,
root: str,
loader: Callable,
-        extensions: Tuple[str] = Graph_EXTENSIONS, #todo: Graph_EXTENSIONS is not defined. In PyG the extension .pt is used
+        extensions: Tuple[
+            str] = Graph_EXTENSIONS,  #todo: Graph_EXTENSIONS is not defined. In PyG the extension .pt is used
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
is_valid_file: Optional[Callable] = None,
@@ -247,6 +250,7 @@ def _make_dataset(self, dir, class_to_idx):
def _is_graph_file(self, filename):
return any(filename.endswith(extension) for extension in self.extensions)


class GraphClassificationData(DataModule):
"""Data module for graph classification tasks."""

@@ -298,7 +302,6 @@ def from_filepaths(
if isinstance(test_filepaths, str):
test_filepaths = [os.path.join(test_filepaths, x) for x in os.listdir(test_filepaths)]


train_ds = FilepathDataset(
filepaths=train_filepaths,
labels=train_labels,
@@ -397,6 +400,7 @@ def from_folders(
)
return datamodule


class GraphClassificationDataPipeline(ClassificationDataPipeline):

def __init__(
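The data.py diff ends here, with the body of GraphClassificationDataPipeline collapsed. To make the docstring's "DataSet -> DataLoader -> DataModule -> DataPipeline" flow concrete, here is a minimal sketch of building the DataModule from files. Beyond the names visible in the diff (from_filepaths, train_filepaths, train_labels, test_filepaths), the paths, labels, and directory layout are assumptions, not confirmed API:

# Sketch only; the import path mirrors this PR's directory layout (hypothetical).
from flash.graph.GraphClassification.data import GraphClassificationData

datamodule = GraphClassificationData.from_filepaths(
    train_filepaths="data/graphs/train",  # hypothetical folder of data_{i}.pt files, expanded via os.listdir in the diff
    train_labels=[0, 1, 0, 1],            # hypothetical per-graph labels, paired with files by FilepathDataset
    test_filepaths="data/graphs/test",
)
train_loader = datamodule.train_dataloader()  # the DataModule hands back DataLoaders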
25 changes: 13 additions & 12 deletions flash/graph/GraphClassification/model.py
@@ -11,21 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Callable, List, Optional, Tuple, Type, Union, Mapping, Sequence, Union
+from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Type, Union

-import torch
import pytorch_lightning as pl
+import torch
from pytorch_lightning.metrics import Accuracy
from torch import nn
from torch.nn import functional as F
from torch.nn import Linear
-from torch_geometric.nn import GCNConv
-from torch_geometric.nn import global_mean_pool
-
+from torch_geometric.nn import GCNConv, global_mean_pool

from flash.core.classification import ClassificationTask
from flash.core.data import DataPipeline


class GraphClassifier(ClassificationTask):
"""Task that classifies graphs.
@@ -57,10 +56,10 @@ def __init__(

#sizes = [input_size] + hidden + [num_classes]
if model == None:
-            self.model = GCN(in_features = num_features, hidden_channels=hidden, out_features = num_classes)
+            self.model = GCN(in_features=num_features, hidden_channels=hidden, out_features=num_classes)

super().__init__(
-            model = model,
+            model=model,
loss_fn=loss_fn,
optimizer=optimizer,
metrics=metrics,
@@ -75,8 +74,10 @@ def forward(self, data) -> Any:
def default_pipeline() -> ClassificationDataPipeline:
return GraphClassificationData.default_pipeline()


#Taken from https://colab.research.google.com/drive/1I8a0DfQ3fI7Njc62__mVXUlcAleUclnb?usp=sharing#scrollTo=CN3sRVuaQ88l
class GCN(pl.LightningModule):

def __init__(self, num_features, hidden_channels, num_classes):
super(GCN, self).__init__()
torch.manual_seed(12345)
@@ -86,7 +87,7 @@ def __init__(self, num_features, hidden_channels, num_classes):
self.lin = Linear(hidden_channels, num_classes)

def forward(self, x, edge_index, batch):
-        # 1. Obtain node embeddings
+        # 1. Obtain node embeddings
x = self.conv1(x, edge_index)
x = x.relu()
x = self.conv2(x, edge_index)
@@ -102,11 +103,11 @@ def forward(self, x, edge_index, batch):

return x

-    def training_step(self, batch, batch_idx): #todo: is this needed?
+    def training_step(self, batch, batch_idx):  #todo: is this needed?
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
return loss
-    def configure_optimizers(self): #todo: is this needed?
-        return torch.optim.Adam(self.parameters(), lr=0.02)

+    def configure_optimizers(self):  #todo: is this needed?
+        return torch.optim.Adam(self.parameters(), lr=0.02)
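Taken together, the forward pass above follows the standard three-stage graph classification recipe: per-node embeddings from stacked GCNConv layers, a global_mean_pool readout over each graph in the batch, and a final Linear classifier. One caveat visible in the diff: GraphClassifier constructs GCN with in_features=/out_features= keywords, while GCN.__init__ is declared with num_features/hidden_channels/num_classes, so those keywords would not match at runtime. A minimal sketch of exercising GCN directly, using the declared parameter names (toy shapes, assuming conv1 maps num_features to hidden_channels as in the Colab it cites):

# Sketch: run the GCN above on a one-graph batch (illustrative values only).
import torch
from torch_geometric.data import Batch, Data

from flash.graph.GraphClassification.model import GCN  # hypothetical import path per this PR

model = GCN(num_features=16, hidden_channels=64, num_classes=2)
graph = Data(
    x=torch.randn(4, 16),                                   # 4 nodes, 16 features each
    edge_index=torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]),  # 4 directed edges, shape [2, E]
)
batch = Batch.from_data_list([graph])  # adds the batch vector mapping each node to its graph
logits = model(batch.x, batch.edge_index, batch.batch)  # expected shape: [1, num_classes]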
