Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Choice of model implemented (you can pass a model to GraphClassifier)
Browse files Browse the repository at this point in the history
The class BasicGraphDataset in graphClassification/data.py is probably unneded
  • Loading branch information
pablo authored and pablo committed May 14, 2021
1 parent a568bda commit 1f326e7
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 13 deletions.
26 changes: 22 additions & 4 deletions flash/graph/GraphClassification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
class BasicGraphDataset(Dataset):

'''
Probably unnecessary having the following class.
#todo: Probably unnecessary having the following class.
'''

def __init__(self, root = None, processed_dir = 'processed', raw_dir = 'raw', transform=None, pre_transform=None, pre_filter=None):
Expand Down Expand Up @@ -94,6 +94,10 @@ def __init__(
if self.has_labels:
self.label_to_class_mapping = {v: k for k, v in enumerate(list(sorted(list(set(self.fnames)))))}

@property
def has_dict_labels(self) -> bool:
return isinstance(self.labels, dict)

@property
def has_labels(self) -> bool:
return self.labels is not None
Expand All @@ -105,6 +109,10 @@ def __getitem__(self, index: int) -> Tuple[Any, Optional[int]]:
filename = self.fnames[index]
graph = self.loader(filename)
label = None
if self.has_dict_labels:
name = os.path.splitext(filename)[0]
name = os.path.basename(name)
label = self.labels[name]
if self.has_labels:
label = self.label_to_class_mapping[filename]
return graph, label
Expand Down Expand Up @@ -136,7 +144,7 @@ class FlashDatasetFolder(torch.utils.data.Dataset):
with_targets: Whether to include targets
graph_paths: List of graph paths to load. Only used when ``with_targets=False``
Attributes:
Attributes:
classes (list): List of the class names sorted alphabetically.
class_to_idx (dict): Dict with items (class_name, class_index).
samples (list): List of (sample path, class_index) tuples
Expand All @@ -147,7 +155,7 @@ def __init__(
self,
root: str,
loader: Callable,
extensions: Tuple[str] = Graph_EXTENSIONS,
extensions: Tuple[str] = Graph_EXTENSIONS, #todo: Graph_EXTENSIONS is not defined. In PyG the extension .pt is used
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
is_valid_file: Optional[Callable] = None,
Expand Down Expand Up @@ -175,7 +183,7 @@ def __init__(
else:
if not graph_paths:
raise MisconfigurationException(
"`FlashDatasetFolder(with_target=False)` but no `img_paths` were provided"
"`FlashDatasetFolder(with_target=False)` but no `graph_paths` were provided"
)
self.samples = graph_paths

Expand Down Expand Up @@ -281,6 +289,16 @@ def from_filepaths(
>>> _data = GraphClassificationData.from_filepaths(["a.pt", "b.pt"], [0, 1]) # doctest: +SKIP
"""

# enable passing in a string which loads all files in that folder as a list
if isinstance(train_filepaths, str):
train_filepaths = [os.path.join(train_filepaths, x) for x in os.listdir(train_filepaths)]
if isinstance(valid_filepaths, str):
valid_filepaths = [os.path.join(valid_filepaths, x) for x in os.listdir(valid_filepaths)]
if isinstance(test_filepaths, str):
test_filepaths = [os.path.join(test_filepaths, x) for x in os.listdir(test_filepaths)]


train_ds = FilepathDataset(
filepaths=train_filepaths,
labels=train_labels,
Expand Down
30 changes: 21 additions & 9 deletions flash/graph/GraphClassification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
from typing import Any, Callable, List, Optional, Tuple, Type, Union, Mapping, Sequence, Union

import torch
import pytorch_lightning as pl
from pytorch_lightning.metrics import Accuracy
from torch import nn
from torch.nn import functional as F
from torch.nn import Linear
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool


from flash.core.classification import ClassificationTask
from flash.core.data import DataPipeline

Expand All @@ -47,36 +49,37 @@ def __init__(
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
metrics: Union[Callable, Mapping, Sequence, None] = [Accuracy()],
learning_rate: float = 1e-3,
model: torch.nn.Module = None,
):

if isinstance(hidden, int):
hidden = [hidden]

#sizes = [input_size] + hidden + [num_classes]
if model == None:
self.model = GCN(in_features = num_features, hidden_channels=hidden, out_features = num_classes)

super().__init__(
model = GCN(in_features = num_features, hidden_channels=hidden, out_features = num_classes),
model = model,
loss_fn=loss_fn,
optimizer=optimizer,
metrics=metrics,
learning_rate=learning_rate,
)

#train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

def forward(self, data) -> Any:
x = self.model(data.x, data.edge_index, data.batch) #This line is probably something to change
x = self.model(data.x, data.edge_index, data.batch)
return self.head(x)

@staticmethod
def default_pipeline() -> ClassificationDataPipeline:
return GraphClassificationData.default_pipeline()

#Taken from https://colab.research.google.com/drive/1I8a0DfQ3fI7Njc62__mVXUlcAleUclnb?usp=sharing#scrollTo=CN3sRVuaQ88l
class GCN(torch.nn.Module):
class GCN(pl.LightningModule):
def __init__(self, num_features, hidden_channels, num_classes):
super(GCN, self).__init__() #I don't understand why we need to call super here with GCN as an argument
#torch.manual_seed(12345)
super(GCN, self).__init__()
torch.manual_seed(12345)
self.conv1 = GCNConv(num_features, hidden_channels)
self.conv2 = GCNConv(hidden_channels, hidden_channels)
self.conv3 = GCNConv(hidden_channels, hidden_channels)
Expand All @@ -96,5 +99,14 @@ def forward(self, x, edge_index, batch):
# 3. Apply a final classifier
x = F.dropout(x, p=0.5, training=self.training)
x = self.lin(x)

return x

return x

def training_step(self, batch, batch_idx): #todo: is this needed?
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
return loss

def configure_optimizers(self): #todo: is this needed?
return torch.optim.Adam(self.parameters(), lr=0.02)

0 comments on commit 1f326e7

Please sign in to comment.