Introduce RenateLightningModule (#301)
wistuba authored Jun 21, 2023
1 parent e4984aa commit c526510
Showing 30 changed files with 472 additions and 269 deletions.
2 changes: 1 addition & 1 deletion examples/getting_started/renate_config.py
@@ -93,7 +93,7 @@ def buffer_transform() -> Callable:


def metrics_fn() -> Dict:
return {"my_accuracy": Accuracy()}
return {"accuracy": Accuracy()}


def loss_fn() -> torch.nn.Module:
4 changes: 2 additions & 2 deletions examples/nlp_finetuning/start.py
@@ -25,8 +25,8 @@

run_training_job(
config_space=config_space,
mode="max",
metric="val_accuracy",
mode="min",
metric="val_loss",
updater="ER", # we train with Experience Replay
max_epochs=5,
config_file="renate_config.py",
9 changes: 7 additions & 2 deletions examples/simple_classifier_cifar10/renate_config.py
@@ -1,8 +1,9 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, Optional
from typing import Callable, Dict, Optional

import torch
from torchmetrics import Accuracy
from torchvision import transforms

import renate.defaults as defaults
@@ -36,7 +37,7 @@ def data_module_fn(data_path: str, chunk_id: int, seed: int = defaults.SEED) ->
)
class_incremental_scenario = ClassIncrementalScenario(
data_module=data_module,
class_groupings=[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]],
class_groupings=((0, 1, 2, 3, 4), (5, 6, 7, 8, 9)),
chunk_id=chunk_id,
)
return class_incremental_scenario
@@ -65,3 +66,7 @@ def buffer_transform() -> Callable:

def loss_fn() -> torch.nn.Module:
return torch.nn.CrossEntropyLoss(reduction="none")


def metrics_fn() -> Dict:
return {"accuracy": Accuracy()}
9 changes: 7 additions & 2 deletions examples/train_mlp_locally/renate_config.py
@@ -1,8 +1,9 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, Optional
from typing import Callable, Dict, Optional

import torch
from torchmetrics import Accuracy
from torchvision.transforms import transforms

from renate import defaults
@@ -27,7 +28,7 @@ def data_module_fn(data_path: str, chunk_id: int, seed: int = defaults.SEED) ->

class_incremental_scenario = ClassIncrementalScenario(
data_module=data_module,
class_groupings=[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]],
class_groupings=((0, 1, 2, 3, 4), (5, 6, 7, 8, 9)),
chunk_id=chunk_id,
)
return class_incremental_scenario
@@ -52,3 +53,7 @@ def train_transform() -> Callable:

def loss_fn() -> torch.nn.Module:
return torch.nn.CrossEntropyLoss(reduction="none")


def metrics_fn() -> Dict:
return {"accuracy": Accuracy()}
11 changes: 8 additions & 3 deletions src/renate/benchmark/experiment_config.py
@@ -1,9 +1,10 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
import wild_time_data
from torchmetrics import Accuracy
from torchvision.transforms import transforms
from transformers import AutoTokenizer

@@ -291,7 +292,7 @@ def _get_normalize_transform(dataset_name):
)


def train_transform(dataset_name: str) -> Optional[transforms.Compose]:
def train_transform(dataset_name: str) -> Optional[Callable]:
"""Returns a transform function to be used in the training."""
if dataset_name in [
"MNIST",
@@ -317,7 +318,7 @@ def train_transform(dataset_name: str) -> Optional[transforms.Compose]:
raise ValueError(f"Unknown dataset `{dataset_name}`.")


def test_transform(dataset_name: str) -> Optional[transforms.Normalize]:
def test_transform(dataset_name: str) -> Optional[Callable]:
"""Returns a transform function to be used for validation or testing."""
if dataset_name in [
"MNIST",
@@ -335,3 +336,7 @@ def test_transform(dataset_name: str) -> Optional[transforms.Normalize]:
]
)
raise ValueError(f"Unknown dataset `{dataset_name}`.")


def metrics_fn() -> Dict:
return {"accuracy": Accuracy()}
28 changes: 12 additions & 16 deletions src/renate/benchmark/experimentation.py
@@ -47,22 +47,16 @@ def experiment_config_file():
return str(Path(renate.__path__[0]) / "benchmark" / "experiment_config.py")


def create_cumulative_metrics(task: defaults.SUPPORTED_TASKS_TYPE) -> List[Tuple[str, Callable]]:
def create_cumulative_metrics() -> List[Tuple[str, Callable]]:
"""Gets the cumulative metrics for a given task along with a name of the metric to include in
any potential results table.
Args:
task: Whether classification or regression, for now.
"""
if task == "classification":
return [
("Average Accuracy", average_accuracy),
("Forgetting", forgetting),
("Forward Transfer", forward_transfer),
("Backward Transfer", backward_transfer),
]
else:
raise NotImplementedError(f"Task {task} not implemented.")
return [
("Average Accuracy", average_accuracy),
("Forgetting", forgetting),
("Forward Transfer", forward_transfer),
("Backward Transfer", backward_transfer),
]


def cumulative_metrics_summary(
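With the `task` argument gone, callers iterate the fixed list of `(display name, callable)` pairs. A minimal usage sketch based only on the signature above:

```python
from renate.benchmark.experimentation import create_cumulative_metrics

# The function no longer takes a task argument and always returns the
# four classification summary metrics.
for name, metric_fn in create_cumulative_metrics():
    print(name)  # Average Accuracy, Forgetting, Forward Transfer, Backward Transfer
```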
@@ -183,8 +177,10 @@ def execute_experiment_job(
deterministic_trainer: When true the Trainer adopts a deterministic behaviour also on GPU.
In this function this parameter is set to True by default.
job_name: Name of the experiment job.
strategy: String denoting lightning distributed strategy.
precision: String for which precision to use.
strategy: Name of the distributed training strategy to use.
`More details <https://lightning.ai/docs/pytorch/stable/extensions/strategy.html>`__
precision: Type of bit precision to use.
`More details <https://lightning.ai/docs/pytorch/stable/common/precision_basic.html>`__
retain_intermediate_state: Flag to retain models and buffer states after each
task update. This is useful when training with large datasets that might cause storage
issues.
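The `strategy` and `precision` strings presumably map onto the PyTorch Lightning `Trainer` arguments of the same names (accepted values are listed in the linked Lightning docs). An illustrative sketch, not Renate's actual wiring:

```python
import pytorch_lightning as pl

# Illustrative only: "ddp" enables multi-GPU data parallelism and
# precision=16 enables 16-bit mixed-precision training.
trainer = pl.Trainer(strategy="ddp", precision=16)
```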
@@ -385,7 +381,7 @@ def _execute_experiment_job_locally(
logger.info(f"### Results after update {update_id + 1}: ###")
logger.info(df)

cumulative_metrics = create_cumulative_metrics("classification")
cumulative_metrics = create_cumulative_metrics()
df = cumulative_metrics_summary(results, cumulative_metrics, num_updates - 1)
save_pandas_df_to_csv(df, defaults.metric_summary_file(logs_url))
if not retain_intermediate_state:
52 changes: 35 additions & 17 deletions src/renate/benchmark/scenarios.py
@@ -1,7 +1,7 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import abc
from typing import Callable, List, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import torch
@@ -26,7 +26,7 @@ class Scenario(abc.ABC):
subsequent instantiations. The seed argument is required for these scenarios.
Args:
data_module: The source RenateDataModule for the the user data.
data_module: The source RenateDataModule for the user data.
num_tasks: The total number of expected tasks for experimentation.
chunk_id: The data chunk to load in for the training or validation data.
seed: Seed used to fix random number generation.
@@ -45,9 +45,12 @@ def __init__(
self._verify_chunk_id(chunk_id)
self._chunk_id = chunk_id
self._seed = seed
self._train_data: Dataset = None
self._val_data: Dataset = None
self._test_data: List[Dataset] = None
self._train_data: Optional[Dataset] = None
self._val_data: Optional[Dataset] = None
self._test_data: Optional[List[Dataset]] = None
self._train_collate_fn: Optional[Callable] = None
self._val_collate_fn: Optional[Callable] = None
self._test_collate_fn: Optional[Callable] = None

def prepare_data(self) -> None:
"""Downloads datasets."""
@@ -56,7 +59,10 @@ def prepare_data(self) -> None:
@abc.abstractmethod
def setup(self) -> None:
"""Sets up the scenario."""
pass
self._data_module.setup()
self._train_collate_fn = self._data_module.train_collate_fn()
self._val_collate_fn = self._data_module.val_collate_fn()
self._test_collate_fn = self._data_module.test_collate_fn()

def train_data(self) -> Dataset:
"""Returns training dataset with respect to current `chunk_id`."""
@@ -70,6 +76,18 @@ def test_data(self) -> List[Dataset]:
"""Returns the test data with respect to all tasks in `num_tasks`."""
return self._test_data

def train_collate_fn(self) -> Optional[Callable]:
"""Returns collate_fn for train DataLoader."""
return self._train_collate_fn

def val_collate_fn(self) -> Optional[Callable]:
"""Returns collate_fn for validation DataLoader."""
return self._val_collate_fn

def test_collate_fn(self) -> Optional[Callable]:
"""Returns collate_fn for test DataLoader."""
return self._test_collate_fn

def _verify_chunk_id(self, chunk_id: int) -> None:
"""A helper function to verify that the `chunk_id` is valid."""
assert 0 <= chunk_id < self._num_tasks
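A `collate_fn` is the standard PyTorch hook that turns a list of samples into a batch. A minimal sketch of the kind of callable these accessors return — `pad_collate` is an illustrative name, not part of Renate:

```python
from typing import List, Tuple

import torch


def pad_collate(batch: List[Tuple[torch.Tensor, int]]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pad variable-length inputs to a common length and stack the labels."""
    xs, ys = zip(*batch)
    return torch.nn.utils.rnn.pad_sequence(xs, batch_first=True), torch.tensor(ys)
```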
@@ -90,7 +108,7 @@ class BenchmarkScenario(Scenario):
"""

def setup(self) -> None:
self._data_module.setup()
super().setup()
self._train_data = self._data_module.train_data()
self._val_data = self._data_module.val_data()
self._test_data = self._data_module._test_data
@@ -108,7 +126,7 @@ class ClassIncrementalScenario(Scenario):
and `y` is the class id.
Args:
data_module: The source RenateDataModule for the the user data.
data_module: The source RenateDataModule for the user data.
chunk_id: The data chunk to load in for the training or validation data.
class_groupings: List of lists, describing the division of the classes for respective tasks.
"""
@@ -117,14 +135,14 @@ def __init__(
self,
data_module: RenateDataModule,
chunk_id: int,
class_groupings: Tuple[Tuple[int]],
class_groupings: Tuple[Tuple[int, ...], ...],
) -> None:
super().__init__(data_module, len(class_groupings), chunk_id)
self._class_groupings = class_groupings

def setup(self) -> None:
"""Make assignments: val/train/test splits."""
self._data_module.setup()
super().setup()
self._train_data = self._get_task_subset(
self._data_module.train_data(), chunk_id=self._chunk_id
)
@@ -178,7 +196,7 @@ def __init__(
self._transforms = transforms

def setup(self) -> None:
self._data_module.setup()
super().setup()
self._split_and_assign_train_and_val_data()
self._train_data = _TransformedDataset(
self._train_data, transform=self._transforms[self._chunk_id]
@@ -249,7 +267,7 @@ class IIDScenario(Scenario):

def setup(self) -> None:
"""Make assignments: val/train/test splits."""
self._data_module.setup()
super().setup()
proportions = [1 / self._num_tasks for _ in range(self._num_tasks)]
self._train_data = randomly_split_data(
self._data_module.train_data(), proportions, self._seed
@@ -304,7 +322,7 @@ def _split(self, dataset: Dataset) -> List[Dataset]:

def setup(self) -> None:
"""Make assignments: val/train/test splits."""
self._data_module.setup()
super().setup()
train_data = self._data_module.train_data()
self._train_data = self._split(train_data)[self._chunk_id]
val_data = self._data_module.val_data()
@@ -324,12 +342,12 @@ class FeatureSortingScenario(_SortingScenario):
the features.
Args:
data_module: The source RenateDataModule for the the user data.
data_module: The source RenateDataModule for the user data.
num_tasks: The total number of expected tasks for experimentation.
feature_idx: Index of the feature by which to sort. This index refers to the input features
`x` of a single data point, i.e., no batch dimension. If the tensor `x` has more than
one dimension, this indexes along the 0-dim while additional dimensions will be averaged
out. Hence, for images, `feature_idx` refers to a color channel and we sort by mean
out. Hence, for images, `feature_idx` refers to a color channel, and we sort by mean
color channel value.
randomness: A value between 0 and 1. For a dataset with ``N`` data points,
``0.5 * N * randomness`` random pairs are swapped.
@@ -388,7 +406,7 @@ class WildTimeScenario(Scenario):
the test set is all data up to the current time step.
Args:
data_module: The source RenateDataModule for the the user data.
data_module: The source RenateDataModule for the user data.
num_tasks: The total number of expected tasks for experimentation.
chunk_id: The data chunk to load in for the training or validation data.
seed: Seed used to fix random number generation.
@@ -408,7 +426,7 @@ def __init__(
def setup(self) -> None:
"""Sets up the scenario."""
self._data_module.time_step = self._chunk_id
self._data_module.setup()
super().setup()
self._train_data = self._data_module.train_data()
self._val_data = self._data_module.val_data()
self._test_data = []
2 changes: 2 additions & 0 deletions src/renate/cli/run_training.py
@@ -177,6 +177,8 @@ def run(self):
model_updater.update(
train_dataset=data_module.train_data(),
val_dataset=data_module.val_data(),
train_dataset_collate_fn=data_module.train_collate_fn(),
val_dataset_collate_fn=data_module.val_collate_fn(),
task_id=args.task_id,
)
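Downstream, the updater presumably forwards these callables to the corresponding `DataLoader`s — the standard PyTorch pattern, shown here for illustration only:

```python
from torch.utils.data import DataLoader

# Illustrative: a collate_fn of None makes the DataLoader fall back to
# PyTorch's default batching, so data modules without custom collation
# are unaffected.
train_loader = DataLoader(
    data_module.train_data(),
    batch_size=32,
    collate_fn=data_module.train_collate_fn(),
)
```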

17 changes: 16 additions & 1 deletion src/renate/data/data_module.py
@@ -3,7 +3,7 @@
import abc
import os
from pathlib import Path
from typing import Optional, Tuple, Union
from typing import Callable, Optional, Tuple, Union

import pandas as pd
import torch
@@ -55,6 +55,9 @@ def __init__(
self._train_data: Optional[Dataset] = None
self._val_data: Optional[Dataset] = None
self._test_data: Optional[Dataset] = None
self._train_collate_fn: Optional[Callable] = None
self._val_collate_fn: Optional[Callable] = None
self._test_collate_fn: Optional[Callable] = None
assert 0.0 <= val_size <= 1.0
self._val_size = val_size
self._seed = seed
@@ -83,6 +86,18 @@ def test_data(self) -> Dataset:
"""Returns test dataset."""
return self._test_data

def train_collate_fn(self) -> Optional[Callable]:
"""Returns collate_fn for train DataLoader."""
return self._train_collate_fn

def val_collate_fn(self) -> Optional[Callable]:
"""Returns collate_fn for validation DataLoader."""
return self._val_collate_fn

def test_collate_fn(self) -> Optional[Callable]:
"""Returns collate_fn for test DataLoader."""
return self._test_collate_fn

def _verify_file(self, file_name: str) -> bool:
"""A helper function that verifies that the required dataset files are downloaded and
correct.
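Subclasses opt in by assigning the new private attributes, typically inside `setup()`; the public accessors above then expose them to the training loop. A hedged sketch — `MyTextDataModule` and `pad_collate` are illustrative names, not part of Renate:

```python
class MyTextDataModule(RenateDataModule):
    def setup(self) -> None:
        ...  # assign self._train_data, self._val_data, self._test_data here
        # Leaving the collate attributes as None keeps default batching.
        self._train_collate_fn = pad_collate
        self._val_collate_fn = pad_collate
        self._test_collate_fn = pad_collate
```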
1 change: 0 additions & 1 deletion src/renate/defaults.py
@@ -43,7 +43,6 @@
FRAMEWORK_VERSION = "1.12.0"

TASK_ID = "default_task"
SUPPORTED_TASKS_TYPE = Literal["classification", "regression"]
WORKING_DIRECTORY = "renate_working_dir"
LOGGER = TensorBoardLogger
LOGGER_KWARGS = {