Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Try fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed Nov 15, 2021
1 parent 1df3d9c commit 4665310
Showing 1 changed file with 12 additions and 30 deletions.
42 changes: 12 additions & 30 deletions flash/image/classification/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,25 +57,6 @@ def remap(self, data, mapping):
return data


class NoModule:

"""This class is used to prevent nn.Module infinite recursion."""

def __init__(self, task):
self.task = task

def __getattr__(self, key):
if key != "task":
return getattr(self.task, key)
return self.task

def __setattr__(self, key: str, value: Any) -> None:
if key == "task":
object.__setattr__(self, key, value)
return
setattr(self.task, key, value)


class Model(torch.nn.Module):
def __init__(self, backbone: torch.nn.Module, head: Optional[torch.nn.Module]):
super().__init__()
Expand All @@ -97,7 +78,6 @@ class Learn2LearnAdapter(Adapter):

def __init__(
self,
task: AdapterTask,
backbone: torch.nn.Module,
head: torch.nn.Module,
algorithm_cls: Type[LightningModule],
Expand Down Expand Up @@ -143,7 +123,6 @@ def __init__(

super().__init__()

self._task = NoModule(task)
self.backbone = backbone
self.head = head
self.algorithm_cls = algorithm_cls
Expand Down Expand Up @@ -309,7 +288,9 @@ def from_task(
"The `shots` should be provided training_strategy_kwargs={'shots'=...}. "
"This is equivalent to the number of sample per label to select within a task."
)
return cls(task, backbone, head, algorithm, **kwargs)
adapter = cls(backbone, head, algorithm, **kwargs)
adapter["_task"] = task
return adapter

def training_step(self, batch, batch_idx) -> Any:
input = (batch[DataKeys.INPUT], batch[DataKeys.TARGET])
Expand Down Expand Up @@ -337,7 +318,7 @@ def _sanetize_batch_size(self, batch_size: int) -> int:
warning_cache.warn(
"When using a meta-learning training_strategy, the batch_size should be set to 1. "
"HINT: You can modify the `meta_batch_size` to 100 for example by doing "
f"{type(self._task.task)}" + "(training_strategies_kwargs={'meta_batch_size': 100})"
f"{type(self._task)}" + "(training_strategies_kwargs={'meta_batch_size': 100})"
)
return 1

Expand Down Expand Up @@ -486,10 +467,9 @@ class DefaultAdapter(Adapter):

required_extras: str = "image"

def __init__(self, task: AdapterTask, backbone: torch.nn.Module, head: torch.nn.Module):
def __init__(self, backbone: torch.nn.Module, head: torch.nn.Module):
super().__init__()

self._task = NoModule(task)
self.backbone = backbone
self.head = head

Expand All @@ -503,23 +483,25 @@ def from_task(
head: torch.nn.Module,
**kwargs,
) -> Adapter:
return cls(task, backbone, head)
adapter = cls(backbone, head)
adapter.__dict__["_task"] = task
return adapter

def training_step(self, batch: Any, batch_idx: int) -> Any:
batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET])
return Task.training_step(self._task.task, batch, batch_idx)
return Task.training_step(self._task, batch, batch_idx)

def validation_step(self, batch: Any, batch_idx: int) -> Any:
batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET])
return Task.validation_step(self._task.task, batch, batch_idx)
return Task.validation_step(self._task, batch, batch_idx)

def test_step(self, batch: Any, batch_idx: int) -> Any:
batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET])
return Task.test_step(self._task.task, batch, batch_idx)
return Task.test_step(self._task, batch, batch_idx)

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
batch[DataKeys.PREDS] = Task.predict_step(
self._task.task, (batch[DataKeys.INPUT]), batch_idx, dataloader_idx=dataloader_idx
self._task, (batch[DataKeys.INPUT]), batch_idx, dataloader_idx=dataloader_idx
)
return batch

Expand Down

0 comments on commit 4665310

Please sign in to comment.