diff --git a/src/diffusers/utils/outputs.py b/src/diffusers/utils/outputs.py index 37b11561d1e1..802c699eb9cc 100644 --- a/src/diffusers/utils/outputs.py +++ b/src/diffusers/utils/outputs.py @@ -16,7 +16,7 @@ """ from collections import OrderedDict -from dataclasses import fields +from dataclasses import fields, is_dataclass from typing import Any, Tuple import numpy as np @@ -101,6 +101,13 @@ def __setitem__(self, key, value): # Don't call self.__setattr__ to avoid recursion errors super().__setattr__(key, value) + def __reduce__(self): + if not is_dataclass(self): + return super().__reduce__() + callable, _args, *remaining = super().__reduce__() + args = tuple(getattr(self, field.name) for field in fields(self)) + return callable, args, *remaining + def to_tuple(self) -> Tuple[Any]: """ Convert self to a tuple containing all the attributes/keys that are not `None`. diff --git a/tests/others/test_outputs.py b/tests/others/test_outputs.py index 50cbd1d54ee4..492e71f0ba31 100644 --- a/tests/others/test_outputs.py +++ b/tests/others/test_outputs.py @@ -1,3 +1,4 @@ +import pickle as pkl import unittest from dataclasses import dataclass from typing import List, Union @@ -58,3 +59,13 @@ def test_outputs_dict_init(self): assert isinstance(outputs["images"][0], PIL.Image.Image) assert isinstance(outputs[0], list) assert isinstance(outputs[0][0], PIL.Image.Image) + + def test_outputs_serialization(self): + outputs_orig = CustomOutput(images=[PIL.Image.new("RGB", (4, 4))]) + serialized = pkl.dumps(outputs_orig) + outputs_copy = pkl.loads(serialized) + + # Check original and copy are equal + assert dir(outputs_orig) == dir(outputs_copy) + assert dict(outputs_orig) == dict(outputs_copy) + assert vars(outputs_orig) == vars(outputs_copy)