Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] Add support for distributed training #257

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 39 additions & 20 deletions learn2learn/algorithms/lightning/lightning_episodic_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

try:
from pytorch_lightning import LightningModule
from pytorch_lightning.trainer.states import TrainerFn
except ImportError:
from learn2learn.utils import _ImportRaiser

Expand Down Expand Up @@ -69,6 +70,15 @@ def add_model_specific_args(parent_parser):
)
return parser

@property
def should_cache_data_on_validate(self) -> bool:
    """Tell whether validation batches should be cached rather than scored.

    Algorithms that have to be fitted on newly labelled data override this
    property to return ``True``; this implementation opts out.
    """
    return False

@property
def should_fit_on_validate(self) -> bool:
    """Tell whether the current validation run is used to fit on cached data.

    This holds only when `should_cache_data_on_validate` is truthy and the
    trainer's current function is `TrainerFn.VALIDATING`.
    """
    caching_enabled = self.should_cache_data_on_validate
    if not caching_enabled:
        # Same result as a short-circuited `and`: the falsy flag itself.
        return caching_enabled
    return self.trainer.state.fn == TrainerFn.VALIDATING

def training_step(self, batch, batch_idx):
train_loss, train_accuracy = self.meta_learn(
batch, batch_idx, self.train_ways, self.train_shots, self.train_queries
Expand All @@ -92,26 +102,34 @@ def training_step(self, batch, batch_idx):
return train_loss

def validation_step(self, batch, batch_idx):
    """Run one validation step, either caching the batch or scoring it.

    When `should_fit_on_validate` is true, the batch is handed to
    `cache_on_validate_step` so the algorithm can store its support data;
    nothing is logged and `None` is returned.

    Otherwise the batch is evaluated with `meta_learn` using the test
    ways/shots/queries, `valid_loss` and `valid_accuracy` are logged as
    epoch-level metrics, and the value of `valid_loss.item()` is returned.
    """
    if self.should_fit_on_validate:
        # Used by the algorithm to store the support data.
        self.cache_on_validate_step(batch, batch_idx)
        return None

    valid_loss, valid_accuracy = self.meta_learn(
        batch, batch_idx, self.test_ways, self.test_shots, self.test_queries
    )
    loss_value = valid_loss.item()
    self.log(
        "valid_loss",
        loss_value,
        on_step=False,
        on_epoch=True,
        prog_bar=False,
        logger=True,
    )
    self.log(
        "valid_accuracy",
        valid_accuracy.item(),
        on_step=False,
        on_epoch=True,
        prog_bar=True,
        logger=True,
    )
    return loss_value
tchaton marked this conversation as resolved.
Show resolved Hide resolved

def validation_epoch_end(self, outputs):
    """Fit the algorithm on the cached data once a validation epoch is over.

    Does nothing unless `should_fit_on_validate` is true; `outputs` is not
    used.
    """
    if not self.should_fit_on_validate:
        return
    self.fit_on_validate_epoch_end()

def test_step(self, batch, batch_idx):
test_loss, test_accuracy = self.meta_learn(
Expand Down Expand Up @@ -143,3 +161,4 @@ def configure_optimizers(self):
gamma=self.scheduler_decay,
)
return [optimizer], [lr_scheduler]

tchaton marked this conversation as resolved.
Show resolved Hide resolved
25 changes: 24 additions & 1 deletion learn2learn/algorithms/lightning/lightning_protonet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""
import numpy as np
import torch

from typing import Any
from torch import nn
from learn2learn.utils import accuracy
from learn2learn.nn import PrototypicalClassifier
Expand Down Expand Up @@ -97,6 +97,9 @@ def __init__(self, features, loss=None, **kwargs):
self.features = torch.nn.DataParallel(self.features)
self.classifier = PrototypicalClassifier(distance=self.distance_metric)

self.support = []
self.support_labels = []

@staticmethod
def add_model_specific_args(parent_parser):
parser = LightningEpisodicModule.add_model_specific_args(parent_parser)
Expand All @@ -112,6 +115,10 @@ def add_model_specific_args(parent_parser):
)
return parser

@property
def should_cache_data_on_validate(self) -> bool:
    """Opt in to caching validation batches.

    The cached embeddings and labels are later used to fit the classifier
    in `fit_on_validate_epoch_end`.
    """
    return True

def meta_learn(self, batch, batch_idx, ways, shots, queries):
self.features.train()
data, labels = batch
Expand Down Expand Up @@ -139,3 +146,19 @@ def meta_learn(self, batch, batch_idx, ways, shots, queries):
eval_loss = self.loss(logits, query_labels)
eval_accuracy = accuracy(logits, query_labels)
return eval_loss, eval_accuracy

def cache_on_validate_step(self, batch, batch_idx):
    """Embed a labelled batch and append it to the support caches.

    `batch` is unpacked as `(data, labels)`. Each embedding produced by
    `self.features(data)` is stored in `self.support`, and its label in
    `self.support_labels`, in matching order. `batch_idx` is not used.
    """
    inputs, targets = batch
    pairs = list(zip(self.features(inputs), targets))
    self.support.extend(embedding for embedding, _ in pairs)
    self.support_labels.extend(target for _, target in pairs)
tchaton marked this conversation as resolved.
Show resolved Hide resolved

def fit_on_validate_epoch_end(self):
    """Fit the classifier on everything cached so far, then empty the caches.

    The cached embeddings are stacked into one tensor and the cached labels
    are converted to a tensor before being passed to `self.classifier.fit_`.
    """
    stacked_support = torch.stack(self.support)
    label_tensor = torch.tensor(self.support_labels)
    self.classifier.fit_(stacked_support, label_tensor)
    # Start the next validation run from empty caches.
    self.support, self.support_labels = [], []

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
    """Embed `batch` with `self.features` and classify the embeddings.

    Returns whatever `self.classifier` produces for those embeddings;
    `batch_idx` and `dataloader_idx` are not used.
    """
    return self.classifier(self.features(batch))
156 changes: 131 additions & 25 deletions learn2learn/utils/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,139 @@
"""
Some utilities to interface with PyTorch Lightning.
"""

from typing import Optional, Callable
import learn2learn as l2l
import pytorch_lightning as pl
from torch.utils.data._utils.worker import get_worker_info
from torch.utils.data import IterableDataset
import sys
import tqdm

class TaskDataParallel(IterableDataset):
    """Iterable dataset that yields `devices` freshly sampled tasks per step.

    It should be used when using DataParallel.
    """

    def __init__(
        self,
        tasks: "l2l.data.TaskDataset",
        epoch_length: int,
        devices: int = 1,
        collate_fn: Optional[Callable] = None,
    ):
        """
        This class is used to sample epoch_length tasks to represent an epoch.

        Args:
            tasks: Dataset used to sample tasks; only its `sample()` method is called.
            epoch_length: The expected epoch length. It must be divisible by `devices`.
            devices: Number of devices being used, i.e. how many tasks are sampled
                for each item yielded by the iterator.
            collate_fn: Optional callable applied to the list of items gathered
                from the `devices` sampled tasks.

        Raises:
            ValueError: If `devices` is not positive or `epoch_length` is not
                divisible by `devices`.
        """
        if devices < 1:
            raise ValueError("The number of `devices` should be a positive integer.")
        if epoch_length % devices != 0:
            raise ValueError(
                "The `epoch_length` should be divisible by the number of `devices`."
            )

        self.tasks = tasks
        self.epoch_length = epoch_length
        self.devices = devices
        self.collate_fn = collate_fn
        self.counter = 0

    def __iter__(self) -> 'TaskDataParallel':
        """Reset the task counter and return the dataset as its own iterator."""
        self.counter = 0
        return self

    def __next__(self):
        """Sample `devices` tasks and return their items as one (collated) list."""
        if self.counter >= len(self):
            raise StopIteration
        # Each step consumes `devices` tasks out of the epoch budget, so an
        # epoch yields `epoch_length // devices` items.
        self.counter += self.devices
        tasks = []
        for _ in range(self.devices):
            tasks.extend(self.tasks.sample())
        if self.collate_fn:
            tasks = self.collate_fn(tasks)
        return tasks

    def __len__(self) -> int:
        """Return the epoch length, counted in sampled tasks."""
        return self.epoch_length


class TaskDistributedDataParallel(IterableDataset):
    """Iterable dataset handing a distinct task to every (process, worker) pair."""

    def __init__(
        self,
        taskset: "l2l.data.TaskDataset",
        global_rank: int,
        world_size: int,
        num_workers: int,
        epoch_length: int,
        seed: int,
        requires_divisible: bool = True,
    ):
        """
        This class is used to sample tasks in a distributed setting such as DDP with multiple workers.

        Note: This won't work as expected if `num_workers = 0` and several dataloaders are being iterated on at the same time.

        Args:
            taskset: Dataset used to sample task.
            global_rank: Rank of the current process.
            world_size: Total of number of processes.
            num_workers: Number of workers to be provided to the DataLoader.
                A value of 0 is treated as a single worker.
            epoch_length: The expected epoch length. This requires to be divisible by (num_workers * world_size).
            seed: The seed will be used on __iter__ call and should be the same for all processes.
            requires_divisible: When true, reject an `epoch_length` that is not
                divisible by (num_workers * world_size).

        Raises:
            ValueError: If `requires_divisible` is true and `epoch_length` is
                not divisible by (num_workers * world_size).
        """
        self.taskset = taskset
        self.global_rank = global_rank
        self.world_size = world_size
        self.num_workers = 1 if num_workers == 0 else num_workers
        self.worker_world_size = self.world_size * self.num_workers
        self.epoch_length = epoch_length
        self.seed = seed
        self.iteration = 0
        self.requires_divisible = requires_divisible
        self.counter = 0

        if requires_divisible and epoch_length % self.worker_world_size != 0:
            raise ValueError(
                "The `epoch_length` should be divisible by `world_size * num_workers`."
            )

    def __len__(self) -> int:
        """Return the number of tasks one iteration of this dataset yields."""
        # NOTE(review): every DataLoader worker presumably iterates its own
        # copy of this dataset, so with `num_workers > 1` a process may yield
        # more than this many tasks in total — confirm this is intended.
        return self.epoch_length // self.world_size

    @property
    def worker_id(self) -> int:
        """Id reported by `get_worker_info()`, or 0 when no worker info is available."""
        worker_info = get_worker_info()
        return worker_info.id if worker_info else 0

    @property
    def worker_rank(self) -> int:
        """Index in `[0, world_size * num_workers)` unique to this (process, worker) pair."""
        # Lay the workers out process by process. The earlier additive formula
        # (`global_rank + worker_id + 1`) gave the same results for up to two
        # processes with up to two workers, but produced duplicate indices
        # beyond that, e.g. rank 1/worker 1 and rank 2/worker 0.
        return self.global_rank * self.num_workers + self.worker_id

    def __iter__(self):
        """Reseed for the new iteration, reset the counter and return `self`."""
        self.iteration += 1
        self.counter = 0
        # Seeding with the same value in every process is what lets each of
        # them draw the same sequence of task descriptions in `__next__`.
        pl.seed_everything(self.seed + self.iteration)
        return self

    def __next__(self):
        """Sample one description per (process, worker) slot and load this slot's task."""
        if self.counter >= len(self):
            raise StopIteration
        # All slots sample the full set of descriptions, then each one keeps
        # only the description at its own `worker_rank`.
        task_descriptions = [
            self.taskset.sample_task_description()
            for _ in range(self.worker_world_size)
        ]
        data = self.taskset.get_task(task_descriptions[self.worker_rank])
        self.counter += 1
        return data



class EpisodicBatcher(pl.LightningDataModule):

def __init__(
self,
Expand All @@ -32,38 +154,22 @@ def __init__(
self.test_tasks = test_tasks
self.epoch_length = epoch_length

@staticmethod
def epochify(taskset, epoch_length):
    """Wrap `taskset` in an indexable object with a fixed length.

    The returned object reports `epoch_length` as its length and answers
    every index lookup with a fresh `taskset.sample()`.
    """

    class Epochifier:
        """Fixed-length view whose items are freshly sampled tasks."""

        def __init__(self, tasks, length):
            self.tasks, self.length = tasks, length

        def __len__(self):
            return self.length

        def __getitem__(self, *args, **kwargs):
            # The requested index is ignored: every lookup samples a new task.
            return self.tasks.sample()

    wrapped = Epochifier(taskset, epoch_length)
    return wrapped

def train_dataloader(self):
return EpisodicBatcher.epochify(
return Epochifier(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was Epochifier inadvertently removed? Or should it be replaced by TaskDataParallel?

Also, if I remember correctly Epochifier was preventing us from using pytorch_lightning > 1.0.2. I believe the fix was just to inherit from a data class in lightning. @nightlessbaron do you remember which class and if that was enough?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@seba-1511, quick answer to your first question: Epochifier is replaced by TaskDataParallel

Copy link
Contributor

@nightlessbaron nightlessbaron Sep 16, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@seba-1511, as @pietrolesci mentioned, in the benchmark, we replaced Epochifier with TaskDistributedDataParallel. The downside is that Windows users can't make use of it, as stated in #5358.

I believe the fix was just to inherit from a data class in lightning.

I couldn't figure it out. That is still unresolved.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @seba-1511, yes it was inadvertently removed.

However, I added a TODO. It should be replaced by TaskDataParallel or TaskDistributedDataParallel once they have been tested on your side.

I will copy those files to the Flash side so as not to block the integration until the next learn2learn release.

Best,
T.C

self.train_tasks,
self.epoch_length,
)

def val_dataloader(self):
    """Return the validation tasks wrapped to a fixed epoch length.

    The wrapper is built by `EpisodicBatcher.epochify` from
    `self.validation_tasks` and `self.epoch_length`.
    """
    # `Epochifier` is only defined as a class local to `epochify`, so it
    # cannot be instantiated by name from here; go through the factory.
    return EpisodicBatcher.epochify(
        self.validation_tasks,
        self.epoch_length,
    )

def test_dataloader(self):
    """Return the test tasks wrapped to a fixed epoch length.

    The wrapper is built by `EpisodicBatcher.epochify` from
    `self.test_tasks` and `self.epoch_length`.
    """
    # `Epochifier` is only defined as a class local to `epochify`, so it
    # cannot be instantiated by name from here; go through the factory.
    return EpisodicBatcher.epochify(
        self.test_tasks,
        self.epoch_length,
    )


Expand Down
8 changes: 8 additions & 0 deletions tests/unit/algorithms/lightning_protonet_test_notravis.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ def test_protonets(self):
verbose=False,
)
self.assertTrue(acc[0]["valid_accuracy"] >= 0.20)
trainer.validate(
val_dataloaders=tasksets.validation,
verbose=False,
)
predictions = trainer.predict(
test_dataloaders=tasksets.validation,
verbose=False,
)


if __name__ == "__main__":
Expand Down