Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Remove unused Output saving mechanism (#948)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Nov 8, 2021
1 parent ad64bc7 commit 7634dde
Show file tree
Hide file tree
Showing 4 changed files with 3 additions and 63 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Removed `Output.enable` and `Output.disable` ([#939](https://github.com/PyTorchLightning/lightning-flash/pull/939))

- Removed `OutputTransform.save_sample` and `save_data` hooks ([#948](https://github.com/PyTorchLightning/lightning-flash/pull/948))

## [0.5.2] - 2021-11-05

Expand Down
16 changes: 0 additions & 16 deletions flash/core/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,34 +441,18 @@ def _create_output_transform_processor(
stage: RunningStage,
is_serving: bool = False,
) -> _OutputTransformProcessor:
save_per_sample = None
save_fn = None

output_transform: OutputTransform = self._output_transform

func_names: Dict[str, str] = {
k: self._resolve_function_hierarchy(k, output_transform, stage, object_type=OutputTransform)
for k in self.OUTPUT_TRANSFORM_FUNCS
}

# since postprocessing is exclusive for prediction, we don't have to check the resolution hierarchy here.
if output_transform._save_path:
save_per_sample: bool = self._is_overriden_recursive(
"save_sample", output_transform, OutputTransform, prefix=_STAGES_PREFIX[stage]
)

if save_per_sample:
save_per_sample: Callable = getattr(output_transform, func_names["save_sample"])
else:
save_fn: Callable = getattr(output_transform, func_names["save_data"])

return _OutputTransformProcessor(
getattr(output_transform, func_names["uncollate"]),
getattr(output_transform, func_names["per_batch_transform"]),
getattr(output_transform, func_names["per_sample_transform"]),
output=None if is_serving else self._output,
save_fn=save_fn,
save_per_sample=save_per_sample,
is_serving=is_serving,
)

Expand Down
47 changes: 2 additions & 45 deletions flash/core/data/io/output_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Callable, Mapping, Optional, Sequence, Tuple

import torch
Expand All @@ -27,11 +26,6 @@ class OutputTransform(Properties):
"""The :class:`~flash.core.data.io.output_transform.OutputTransform` encapsulates all the data processing logic
that should run after the model."""

def __init__(self, save_path: Optional[str] = None):
super().__init__()
self._saved_samples = 0
self._save_path = save_path

@staticmethod
def per_batch_transform(batch: Any) -> Any:
"""Transforms to apply on a whole batch before uncollation to individual samples.
Expand All @@ -56,28 +50,6 @@ def uncollate(batch: Any) -> Any:
"""
return default_uncollate(batch)

@staticmethod
def save_data(data: Any, path: str) -> None:
"""Saves all data together to a single path."""
torch.save(data, path)

@staticmethod
def save_sample(sample: Any, path: str) -> None:
"""Saves each sample individually to a given path."""
torch.save(sample, path)

# TODO: Are those needed ?
def format_sample_save_path(self, path: str) -> str:
path = os.path.join(path, f"sample_{self._saved_samples}.ptl")
self._saved_samples += 1
return path

def _save_data(self, data: Any) -> None:
self.save_data(data, self._save_path)

def _save_sample(self, sample: Any) -> None:
self.save_sample(sample, self.format_sample_save_path(self._save_path))


class _OutputTransformProcessor(torch.nn.Module):
"""This class is used to encapsultate the following functions of a OutputTransform Object:
Expand All @@ -87,8 +59,6 @@ class _OutputTransformProcessor(torch.nn.Module):
per_sample_transform: Function to transform an individual sample
uncollate_fn: Function to split a batch into samples
per_sample_transform: Function to transform an individual sample
save_fn: Function to save all data
save_per_sample: Function to save an individual sample
is_serving: Whether the Postprocessor is used in serving mode.
"""

Expand All @@ -98,17 +68,13 @@ def __init__(
per_batch_transform: Callable,
per_sample_transform: Callable,
output: Optional[Callable],
save_fn: Optional[Callable] = None,
save_per_sample: bool = False,
is_serving: bool = False,
):
super().__init__()
self.uncollate_fn = convert_to_modules(uncollate_fn)
self.per_batch_transform = convert_to_modules(per_batch_transform)
self.per_sample_transform = convert_to_modules(per_sample_transform)
self.output = convert_to_modules(output)
self.save_fn = convert_to_modules(save_fn)
self.save_per_sample = convert_to_modules(save_per_sample)
self.is_serving = is_serving

@staticmethod
Expand All @@ -131,17 +97,8 @@ def forward(self, batch: Sequence[Any]):
final_preds = [self.output(sample) for sample in final_preds]

if isinstance(uncollated, Tensor) and isinstance(final_preds[0], Tensor):
final_preds = torch.stack(final_preds)
else:
final_preds = type(final_preds)(final_preds)

if self.save_fn:
if self.save_per_sample:
for pred in final_preds:
self.save_fn(pred)
else:
self.save_fn(final_preds)
return final_preds
return torch.stack(final_preds)
return type(final_preds)(final_preds)

def __str__(self) -> str:
return (
Expand Down
2 changes: 0 additions & 2 deletions flash/core/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,6 @@
"per_batch_transform",
"uncollate",
"per_sample_transform",
"save_sample",
"save_data",
}


Expand Down

0 comments on commit 7634dde

Please sign in to comment.