Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Remove unused Output saving mechanism #948

Merged
merged 2 commits into from
Nov 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Removed `Output.enable` and `Output.disable` ([#939](https://github.com/PyTorchLightning/lightning-flash/pull/939))

- Removed `OutputTransform.save_sample` and `save_data` hooks ([#948](https://github.com/PyTorchLightning/lightning-flash/pull/948))

## [0.5.2] - 2021-11-05

Expand Down
16 changes: 0 additions & 16 deletions flash/core/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,34 +441,18 @@ def _create_output_transform_processor(
stage: RunningStage,
is_serving: bool = False,
) -> _OutputTransformProcessor:
save_per_sample = None
save_fn = None

output_transform: OutputTransform = self._output_transform

func_names: Dict[str, str] = {
k: self._resolve_function_hierarchy(k, output_transform, stage, object_type=OutputTransform)
for k in self.OUTPUT_TRANSFORM_FUNCS
}

# since postprocessing is exclusive for prediction, we don't have to check the resolution hierarchy here.
if output_transform._save_path:
save_per_sample: bool = self._is_overriden_recursive(
"save_sample", output_transform, OutputTransform, prefix=_STAGES_PREFIX[stage]
)

if save_per_sample:
save_per_sample: Callable = getattr(output_transform, func_names["save_sample"])
else:
save_fn: Callable = getattr(output_transform, func_names["save_data"])

return _OutputTransformProcessor(
getattr(output_transform, func_names["uncollate"]),
getattr(output_transform, func_names["per_batch_transform"]),
getattr(output_transform, func_names["per_sample_transform"]),
output=None if is_serving else self._output,
save_fn=save_fn,
save_per_sample=save_per_sample,
is_serving=is_serving,
)

Expand Down
47 changes: 2 additions & 45 deletions flash/core/data/io/output_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Callable, Mapping, Optional, Sequence, Tuple

import torch
Expand All @@ -27,11 +26,6 @@ class OutputTransform(Properties):
"""The :class:`~flash.core.data.io.output_transform.OutputTransform` encapsulates all the data processing logic
that should run after the model."""

def __init__(self, save_path: Optional[str] = None):
super().__init__()
self._saved_samples = 0
self._save_path = save_path

@staticmethod
def per_batch_transform(batch: Any) -> Any:
"""Transforms to apply on a whole batch before uncollation to individual samples.
Expand All @@ -56,28 +50,6 @@ def uncollate(batch: Any) -> Any:
"""
return default_uncollate(batch)

@staticmethod
def save_data(data: Any, path: str) -> None:
"""Saves all data together to a single path."""
torch.save(data, path)

@staticmethod
def save_sample(sample: Any, path: str) -> None:
"""Saves each sample individually to a given path."""
torch.save(sample, path)

# TODO: Are those needed ?
def format_sample_save_path(self, path: str) -> str:
path = os.path.join(path, f"sample_{self._saved_samples}.ptl")
self._saved_samples += 1
return path

def _save_data(self, data: Any) -> None:
self.save_data(data, self._save_path)

def _save_sample(self, sample: Any) -> None:
self.save_sample(sample, self.format_sample_save_path(self._save_path))


class _OutputTransformProcessor(torch.nn.Module):
"""This class is used to encapsultate the following functions of a OutputTransform Object:
Expand All @@ -87,8 +59,6 @@ class _OutputTransformProcessor(torch.nn.Module):
per_sample_transform: Function to transform an individual sample
uncollate_fn: Function to split a batch into samples
per_sample_transform: Function to transform an individual sample
save_fn: Function to save all data
save_per_sample: Function to save an individual sample
is_serving: Whether the Postprocessor is used in serving mode.
"""

Expand All @@ -98,17 +68,13 @@ def __init__(
per_batch_transform: Callable,
per_sample_transform: Callable,
output: Optional[Callable],
save_fn: Optional[Callable] = None,
save_per_sample: bool = False,
is_serving: bool = False,
):
super().__init__()
self.uncollate_fn = convert_to_modules(uncollate_fn)
self.per_batch_transform = convert_to_modules(per_batch_transform)
self.per_sample_transform = convert_to_modules(per_sample_transform)
self.output = convert_to_modules(output)
self.save_fn = convert_to_modules(save_fn)
self.save_per_sample = convert_to_modules(save_per_sample)
self.is_serving = is_serving

@staticmethod
Expand All @@ -131,17 +97,8 @@ def forward(self, batch: Sequence[Any]):
final_preds = [self.output(sample) for sample in final_preds]

if isinstance(uncollated, Tensor) and isinstance(final_preds[0], Tensor):
final_preds = torch.stack(final_preds)
else:
final_preds = type(final_preds)(final_preds)

if self.save_fn:
if self.save_per_sample:
for pred in final_preds:
self.save_fn(pred)
else:
self.save_fn(final_preds)
return final_preds
return torch.stack(final_preds)
return type(final_preds)(final_preds)

def __str__(self) -> str:
return (
Expand Down
2 changes: 0 additions & 2 deletions flash/core/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,6 @@
"per_batch_transform",
"uncollate",
"per_sample_transform",
"save_sample",
"save_data",
}


Expand Down