Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Pipeline taken from images.
Browse files Browse the repository at this point in the history
I'm unsure how to adapt
  • Loading branch information
pablo authored and Borda committed Feb 12, 2021
1 parent 2844d32 commit 2f1cf01
Showing 1 changed file with 30 additions and 1 deletion.
31 changes: 30 additions & 1 deletion flash/graph/GraphClassification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,4 +377,33 @@ def from_folders(
datamodule.data_pipeline = GraphClassificationDataPipeline(
train_transform=train_transform, valid_transform=valid_transform, loader=loader
)
return datamodule
return datamodule

class GraphClassificationDataPipeline(ClassificationDataPipeline):

def __init__(
self,
train_transform: Optional[Callable] = None,
valid_transform: Optional[Callable] = None,
use_valid_transform: bool = True,
loader: Callable = torch.load
):
self._train_transform = train_transform
self._valid_transform = valid_transform
self._use_valid_transform = use_valid_transform
self._loader = loader

def before_collate(self, samples: Any) -> Any:
if _contains_any_tensor(samples):
return samples

if isinstance(samples, str):
samples = [samples]
if isinstance(samples, (list, tuple)) and all(isinstance(p, str) for p in samples):
outputs = []
for sample in samples:
output = self._loader(sample)
transform = self._valid_transform if self._use_valid_transform else self._train_transform
outputs.append(transform(output))
return outputs
raise MisconfigurationException("The samples should either be a tensor or a list of paths.")

0 comments on commit 2f1cf01

Please sign in to comment.