Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Fix DDP image classification #969

Merged
merged 4 commits into from
Nov 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed a bug where using image classification with DDP spawn would trigger an infinite recursion ([#969](https://github.com/PyTorchLightning/lightning-flash/pull/969))

### Removed

- Removed `OutputMapping` ([#939](https://github.com/PyTorchLightning/lightning-flash/pull/939))
Expand Down
42 changes: 12 additions & 30 deletions flash/image/classification/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,25 +57,6 @@ def remap(self, data, mapping):
return data


class NoModule:

"""This class is used to prevent nn.Module infinite recursion."""

def __init__(self, task):
self.task = task

def __getattr__(self, key):
if key != "task":
return getattr(self.task, key)
return self.task

def __setattr__(self, key: str, value: Any) -> None:
if key == "task":
object.__setattr__(self, key, value)
return
setattr(self.task, key, value)


class Model(torch.nn.Module):
def __init__(self, backbone: torch.nn.Module, head: Optional[torch.nn.Module]):
super().__init__()
Expand All @@ -97,7 +78,6 @@ class Learn2LearnAdapter(Adapter):

def __init__(
self,
task: AdapterTask,
backbone: torch.nn.Module,
head: torch.nn.Module,
algorithm_cls: Type[LightningModule],
Expand Down Expand Up @@ -143,7 +123,6 @@ def __init__(

super().__init__()

self._task = NoModule(task)
self.backbone = backbone
self.head = head
self.algorithm_cls = algorithm_cls
Expand Down Expand Up @@ -309,7 +288,9 @@ def from_task(
"The `shots` should be provided training_strategy_kwargs={'shots'=...}. "
"This is equivalent to the number of sample per label to select within a task."
)
return cls(task, backbone, head, algorithm, **kwargs)
adapter = cls(backbone, head, algorithm, **kwargs)
adapter.__dict__["_task"] = task
return adapter

def training_step(self, batch, batch_idx) -> Any:
input = (batch[DataKeys.INPUT], batch[DataKeys.TARGET])
Expand Down Expand Up @@ -337,7 +318,7 @@ def _sanetize_batch_size(self, batch_size: int) -> int:
warning_cache.warn(
"When using a meta-learning training_strategy, the batch_size should be set to 1. "
"HINT: You can modify the `meta_batch_size` to 100 for example by doing "
f"{type(self._task.task)}" + "(training_strategies_kwargs={'meta_batch_size': 100})"
f"{type(self._task)}" + "(training_strategies_kwargs={'meta_batch_size': 100})"
)
return 1

Expand Down Expand Up @@ -486,10 +467,9 @@ class DefaultAdapter(Adapter):

required_extras: str = "image"

def __init__(self, task: AdapterTask, backbone: torch.nn.Module, head: torch.nn.Module):
def __init__(self, backbone: torch.nn.Module, head: torch.nn.Module):
super().__init__()

self._task = NoModule(task)
self.backbone = backbone
self.head = head

Expand All @@ -503,23 +483,25 @@ def from_task(
head: torch.nn.Module,
**kwargs,
) -> Adapter:
return cls(task, backbone, head)
adapter = cls(backbone, head)
adapter.__dict__["_task"] = task
return adapter

def training_step(self, batch: Any, batch_idx: int) -> Any:
batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET])
return Task.training_step(self._task.task, batch, batch_idx)
return Task.training_step(self._task, batch, batch_idx)

def validation_step(self, batch: Any, batch_idx: int) -> Any:
batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET])
return Task.validation_step(self._task.task, batch, batch_idx)
return Task.validation_step(self._task, batch, batch_idx)

def test_step(self, batch: Any, batch_idx: int) -> Any:
batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET])
return Task.test_step(self._task.task, batch, batch_idx)
return Task.test_step(self._task, batch, batch_idx)

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
batch[DataKeys.PREDS] = Task.predict_step(
self._task.task, (batch[DataKeys.INPUT]), batch_idx, dataloader_idx=dataloader_idx
self._task, (batch[DataKeys.INPUT]), batch_idx, dataloader_idx=dataloader_idx
)
return batch

Expand Down
Empty file.