From 46653107c0937d16368705c1e8d6d2648443cc59 Mon Sep 17 00:00:00 2001
From: Ethan Harris
Date: Mon, 15 Nov 2021 12:01:28 +0000
Subject: [PATCH] Try fix

---
 flash/image/classification/adapters.py | 42 ++++++++------------------
 1 file changed, 12 insertions(+), 30 deletions(-)

diff --git a/flash/image/classification/adapters.py b/flash/image/classification/adapters.py
index 9a5faeb092..922e2666ab 100644
--- a/flash/image/classification/adapters.py
+++ b/flash/image/classification/adapters.py
@@ -57,25 +57,6 @@ def remap(self, data, mapping):
         return data
 
 
-class NoModule:
-
-    """This class is used to prevent nn.Module infinite recursion."""
-
-    def __init__(self, task):
-        self.task = task
-
-    def __getattr__(self, key):
-        if key != "task":
-            return getattr(self.task, key)
-        return self.task
-
-    def __setattr__(self, key: str, value: Any) -> None:
-        if key == "task":
-            object.__setattr__(self, key, value)
-            return
-        setattr(self.task, key, value)
-
-
 class Model(torch.nn.Module):
     def __init__(self, backbone: torch.nn.Module, head: Optional[torch.nn.Module]):
         super().__init__()
@@ -97,7 +78,6 @@ class Learn2LearnAdapter(Adapter):
 
     def __init__(
         self,
-        task: AdapterTask,
         backbone: torch.nn.Module,
         head: torch.nn.Module,
         algorithm_cls: Type[LightningModule],
@@ -143,7 +123,6 @@ def __init__(
 
         super().__init__()
 
-        self._task = NoModule(task)
         self.backbone = backbone
         self.head = head
         self.algorithm_cls = algorithm_cls
@@ -309,7 +288,9 @@ def from_task(
                 "The `shots` should be provided training_strategy_kwargs={'shots'=...}. "
                 "This is equivalent to the number of sample per label to select within a task."
             )
-        return cls(task, backbone, head, algorithm, **kwargs)
+        adapter = cls(backbone, head, algorithm, **kwargs)
+        adapter.__dict__["_task"] = task
+        return adapter
 
     def training_step(self, batch, batch_idx) -> Any:
         input = (batch[DataKeys.INPUT], batch[DataKeys.TARGET])
@@ -337,7 +318,7 @@ def _sanetize_batch_size(self, batch_size: int) -> int:
             warning_cache.warn(
                 "When using a meta-learning training_strategy, the batch_size should be set to 1. "
" "HINT: You can modify the `meta_batch_size` to 100 for example by doing " - f"{type(self._task.task)}" + "(training_strategies_kwargs={'meta_batch_size': 100})" + f"{type(self._task)}" + "(training_strategies_kwargs={'meta_batch_size': 100})" ) return 1 @@ -486,10 +467,9 @@ class DefaultAdapter(Adapter): required_extras: str = "image" - def __init__(self, task: AdapterTask, backbone: torch.nn.Module, head: torch.nn.Module): + def __init__(self, backbone: torch.nn.Module, head: torch.nn.Module): super().__init__() - self._task = NoModule(task) self.backbone = backbone self.head = head @@ -503,23 +483,25 @@ def from_task( head: torch.nn.Module, **kwargs, ) -> Adapter: - return cls(task, backbone, head) + adapter = cls(backbone, head) + adapter.__dict__["_task"] = task + return adapter def training_step(self, batch: Any, batch_idx: int) -> Any: batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET]) - return Task.training_step(self._task.task, batch, batch_idx) + return Task.training_step(self._task, batch, batch_idx) def validation_step(self, batch: Any, batch_idx: int) -> Any: batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET]) - return Task.validation_step(self._task.task, batch, batch_idx) + return Task.validation_step(self._task, batch, batch_idx) def test_step(self, batch: Any, batch_idx: int) -> Any: batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET]) - return Task.test_step(self._task.task, batch, batch_idx) + return Task.test_step(self._task, batch, batch_idx) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: batch[DataKeys.PREDS] = Task.predict_step( - self._task.task, (batch[DataKeys.INPUT]), batch_idx, dataloader_idx=dataloader_idx + self._task, (batch[DataKeys.INPUT]), batch_idx, dataloader_idx=dataloader_idx ) return batch