diff --git a/flash/core/data/io/input_base.py b/flash/core/data/io/input_base.py index 01616bef69..94ef607f49 100644 --- a/flash/core/data/io/input_base.py +++ b/flash/core/data/io/input_base.py @@ -191,6 +191,14 @@ def __setstate__(self, newstate): newstate["data"] = None self.__dict__.update(newstate) + def __copy__(self): + """The default copy implementation seems to use ``__getstate__`` and ``__setstate__`` so we override it + here with a custom implementation to ensure that it includes the data list.""" + cls = self.__class__ + result = cls.__new__(cls) + result.__dict__.update(self.__dict__) + return result + def __deepcopy__(self, memo): """The default deepcopy implementation seems to use ``__getstate__`` and ``__setstate__`` so we override it here with a custom implementation to ensure that it includes the data list."""