Introduce RenateLightningModule
#301
Conversation
Coverage report
The coverage rate went from
Diff Coverage details:
src/renate/data/data_module.py
src/renate/updaters/experimental/repeated_distill.py
src/renate/updaters/experimental/er.py
src/renate/updaters/avalanche/model_updater.py
src/renate/benchmark/scenarios.py
src/renate/updaters/experimental/offline_er.py
src/renate/utils/avalanche.py
src/renate/benchmark/experiment_config.py
src/renate/updaters/learner.py
src/renate/evaluation/evaluator.py
src/renate/benchmark/experimentation.py
src/renate/updaters/experimental/gdumb.py
src/renate/updaters/model_updater.py
src/renate/updaters/experimental/joint.py
@@ -317,7 +318,7 @@ def train_transform(dataset_name: str) -> Optional[transforms.Compose]:
         raise ValueError(f"Unknown dataset `{dataset_name}`.")


-def test_transform(dataset_name: str) -> Optional[transforms.Normalize]:
+def test_transform(dataset_name: str) -> Optional[Union[transforms.Normalize, transforms.Compose]]:
Can this be made a callable? This is very specific to existing transformations.
Why is this also Optional?
Because it can return None.
src/renate/updaters/learner.py
Outdated
self._seed = seed
self._task_id: str = defaults.TASK_ID
self._train_dataset = None
no typing :)
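The annotated version the reviewer is asking for might look like the sketch below; `object` stands in for `torch.utils.data.Dataset` and the literal stands in for `defaults.TASK_ID`, both assumptions made to keep the sketch dependency-free:

```python
from typing import Optional


class LearnerInitSketch:
    """Sketch of the attribute annotations requested in review.

    In Renate, `_train_dataset` would be Optional[torch.utils.data.Dataset]
    and `_task_id` would default to defaults.TASK_ID; both are stubbed here.
    """

    def __init__(self, seed: int = 0) -> None:
        self._seed: int = seed
        self._task_id: str = "default_task"          # stand-in for defaults.TASK_ID
        self._train_dataset: Optional[object] = None  # Optional[Dataset] in Renate
```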
def training_step_unpack_batch(
    self, batch: Tuple[NestedTensors, torch.Tensor]
) -> Tuple[NestedTensors, Any]:
    inputs, targets = batch
This choice of what a batch comprises is specific to HF vs non-HF datasets and to whether there is a buffer. So why is this specified in the base class?
I see. I guess I should change it to Any.
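The resolution discussed here, keeping a default `(inputs, targets)` unpacking in the base class while letting dataset-specific subclasses override it, can be sketched as follows. The class and subclass names are hypothetical, and `NestedTensors` is stubbed as `Any` to keep the sketch torch-free:

```python
from typing import Any, Tuple

# Renate's NestedTensors union (tensors, or nested tuples/dicts of tensors)
# is stubbed as Any here to keep the sketch self-contained.
NestedTensors = Any


class BatchUnpackSketch:
    """Base-class hook (sketch): the default assumes an (inputs, targets) pair."""

    def training_step_unpack_batch(
        self, batch: Tuple[NestedTensors, Any]
    ) -> Tuple[NestedTensors, Any]:
        inputs, targets = batch
        return inputs, targets


class DictBatchLearner(BatchUnpackSketch):
    """Hypothetical subclass for dict-shaped (e.g. HF-style) batches."""

    def training_step_unpack_batch(self, batch: Any) -> Tuple[NestedTensors, Any]:
        # The targets live under a "labels" key; the rest of the dict is input.
        targets = batch.pop("labels")
        return batch, targets
```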
src/renate/updaters/learner.py
Outdated
@@ -253,9 +217,15 @@ def training_epoch_end(self, outputs: List[Union[Tensor, Dict[str, Any]]]) -> No
         if not self.val_enabled:
             self._log_metrics()

+    def validation_step_unpack_batch(
+        self, batch: Tuple[NestedTensors, torch.Tensor]
+    ) -> Tuple[NestedTensors, Any]:
same as above.
@@ -78,7 +78,7 @@ def execute_job():
     mode="max",
     config_space=config_space,
     metric="val_accuracy",
-    max_time=30,
+    max_time=35,
I guess this is continually increasing.
A few small changes. I agree with the overall structure of what is being accomplished.
Split Learner into two classes: RenateLightningModule contains the basic logic, while Learner adds all the CL-related functionality such as val_buffer.
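The split described above can be sketched as follows. Only the two class names come from the PR description; every method name, attribute other than `_val_buffer`, and signature below is an illustrative assumption:

```python
from typing import Any, List


class RenateLightningModule:
    """Sketch: basic training logic only (model, optimizer settings, hooks)."""

    def __init__(self, model: Any, learning_rate: float = 0.1) -> None:
        self._model = model
        self._learning_rate = learning_rate

    def training_step_unpack_batch(self, batch: Any) -> Any:
        inputs, targets = batch
        return inputs, targets


class Learner(RenateLightningModule):
    """Sketch: layers continual-learning state on top of the basic module."""

    def __init__(self, model: Any, learning_rate: float = 0.1) -> None:
        super().__init__(model, learning_rate)
        # CL-specific state such as the validation buffer lives here,
        # not in the base class.
        self._val_buffer: List[Any] = []

    def extend_val_buffer(self, val_dataset: Any) -> None:
        """Hypothetical helper: grow the validation memory across tasks."""
        self._val_buffer.append(val_dataset)
```

The design benefit, as the comment suggests, is that non-CL use cases can depend on RenateLightningModule alone, without carrying buffer-related state.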
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.