diff --git a/CHANGELOG.md b/CHANGELOG.md index 8505006e41..167a9aca6c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,6 +38,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed a bug where using image classification with DDP spawn would trigger an infinite recursion ([#969](https://github.com/PyTorchLightning/lightning-flash/pull/969)) + ### Removed - Removed `OutputMapping` ([#939](https://github.com/PyTorchLightning/lightning-flash/pull/939)) diff --git a/flash/image/classification/adapters.py b/flash/image/classification/adapters.py index 9a5faeb092..8ed89d6fcc 100644 --- a/flash/image/classification/adapters.py +++ b/flash/image/classification/adapters.py @@ -57,25 +57,6 @@ def remap(self, data, mapping): return data -class NoModule: - - """This class is used to prevent nn.Module infinite recursion.""" - - def __init__(self, task): - self.task = task - - def __getattr__(self, key): - if key != "task": - return getattr(self.task, key) - return self.task - - def __setattr__(self, key: str, value: Any) -> None: - if key == "task": - object.__setattr__(self, key, value) - return - setattr(self.task, key, value) - - class Model(torch.nn.Module): def __init__(self, backbone: torch.nn.Module, head: Optional[torch.nn.Module]): super().__init__() @@ -97,7 +78,6 @@ class Learn2LearnAdapter(Adapter): def __init__( self, - task: AdapterTask, backbone: torch.nn.Module, head: torch.nn.Module, algorithm_cls: Type[LightningModule], @@ -143,7 +123,6 @@ def __init__( super().__init__() - self._task = NoModule(task) self.backbone = backbone self.head = head self.algorithm_cls = algorithm_cls @@ -309,7 +288,9 @@ def from_task( "The `shots` should be provided training_strategy_kwargs={'shots'=...}. " "This is equivalent to the number of sample per label to select within a task." ) - return cls(task, backbone, head, algorithm, **kwargs) + adapter = cls(backbone, head, algorithm, **kwargs) + adapter.__dict__["_task"] = task + return adapter def training_step(self, batch, batch_idx) -> Any: input = (batch[DataKeys.INPUT], batch[DataKeys.TARGET]) @@ -337,7 +318,7 @@ def _sanetize_batch_size(self, batch_size: int) -> int: warning_cache.warn( "When using a meta-learning training_strategy, the batch_size should be set to 1. " "HINT: You can modify the `meta_batch_size` to 100 for example by doing " - f"{type(self._task.task)}" + "(training_strategies_kwargs={'meta_batch_size': 100})" + f"{type(self._task)}" + "(training_strategies_kwargs={'meta_batch_size': 100})" ) return 1 @@ -486,10 +467,9 @@ class DefaultAdapter(Adapter): required_extras: str = "image" - def __init__(self, task: AdapterTask, backbone: torch.nn.Module, head: torch.nn.Module): + def __init__(self, backbone: torch.nn.Module, head: torch.nn.Module): super().__init__() - self._task = NoModule(task) self.backbone = backbone self.head = head @@ -503,23 +483,25 @@ def from_task( head: torch.nn.Module, **kwargs, ) -> Adapter: - return cls(task, backbone, head) + adapter = cls(backbone, head) + adapter.__dict__["_task"] = task + return adapter def training_step(self, batch: Any, batch_idx: int) -> Any: batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET]) - return Task.training_step(self._task.task, batch, batch_idx) + return Task.training_step(self._task, batch, batch_idx) def validation_step(self, batch: Any, batch_idx: int) -> Any: batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET]) - return Task.validation_step(self._task.task, batch, batch_idx) + return Task.validation_step(self._task, batch, batch_idx) def test_step(self, batch: Any, batch_idx: int) -> Any: batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET]) - return Task.test_step(self._task.task, batch, batch_idx) + return Task.test_step(self._task, batch, batch_idx) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: batch[DataKeys.PREDS] = Task.predict_step( - self._task.task, (batch[DataKeys.INPUT]), batch_idx, dataloader_idx=dataloader_idx + self._task, (batch[DataKeys.INPUT]), batch_idx, dataloader_idx=dataloader_idx ) return batch diff --git a/flash/pointcloud/detection/open3d_ml/__init__.py b/flash/pointcloud/detection/open3d_ml/__init__.py new file mode 100644 index 0000000000..e69de29bb2