Skip to content
This repository was archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Improve memory usage (#1448)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Sep 5, 2022
1 parent 95ae65f commit 31e76b9
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions flash/core/data/io/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import functools
import os
import sys
from copy import deepcopy
from enum import Enum
from typing import Any, cast, Dict, Iterable, List, Sequence, Tuple, Union

Expand All @@ -40,6 +39,13 @@
IterableDataset = object


def _deepcopy_dict(nested_dict: Any) -> Any:
"""Utility to deepcopy a nested dict."""
if not isinstance(nested_dict, Dict):
return nested_dict
return {key: value for key, value in nested_dict.items()}


class InputFormat(LightningEnum):
"""The ``InputFormat`` enum contains the data source names used by all of the default ``from_*`` methods in
:class:`~flash.core.data.data_module.DataModule`."""
Expand Down Expand Up @@ -172,7 +178,7 @@ def __init__(self, running_stage: RunningStage, *args: Any, **kwargs: Any) -> No

def _call_load_sample(self, sample: Any) -> Any:
# Deepcopy the sample to avoid leaks with complex data structures
sample_output = getattr(self, f"{_STAGES_PREFIX[self.running_stage]}_load_sample")(deepcopy(sample))
sample_output = getattr(self, f"{_STAGES_PREFIX[self.running_stage]}_load_sample")(_deepcopy_dict(sample))

# Change DataKeys Enum to strings
if isinstance(sample_output, dict):
Expand Down

0 comments on commit 31e76b9

Please sign in to comment.