Skip to content

Commit

Permalink
feat(common): make Slotted and FrozenSlotted pickleable
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs authored and cpcloud committed Oct 9, 2023
1 parent 0c60146 commit 13cbce0
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 0 deletions.
13 changes: 13 additions & 0 deletions ibis/common/bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,13 @@ def __eq__(self, other) -> bool:
return NotImplemented
return all(getattr(self, n) == getattr(other, n) for n in self.__slots__)

def __getstate__(self):
return {k: getattr(self, k) for k in self.__slots__}

def __setstate__(self, state):
for name, value in state.items():
object.__setattr__(self, name, value)

def __repr__(self):
fields = {k: getattr(self, k) for k in self.__slots__}
fieldstring = ", ".join(f"{k}={v!r}" for k, v in fields.items())
Expand All @@ -221,5 +228,11 @@ def __init__(self, **kwargs) -> None:
hashvalue = hash(tuple(kwargs.values()))
object.__setattr__(self, "__precomputed_hash__", hashvalue)

def __setstate__(self, state):
for name, value in state.items():
object.__setattr__(self, name, value)
hashvalue = hash(tuple(state.values()))
object.__setattr__(self, "__precomputed_hash__", hashvalue)

def __hash__(self) -> int:
return self.__precomputed_hash__
55 changes: 55 additions & 0 deletions ibis/common/tests/test_bases.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import copy
import pickle
import weakref
from abc import ABCMeta, abstractmethod

Expand All @@ -11,8 +12,10 @@
AbstractMeta,
Comparable,
Final,
FrozenSlotted,
Immutable,
Singleton,
Slotted,
)
from ibis.common.caching import WeakCache

Expand Down Expand Up @@ -258,3 +261,55 @@ class A(Final):

class B(A):
pass


class MyObj(Slotted):
__slots__ = ("a", "b")

def __init__(self, a, b):
super().__init__(a=a, b=b)


def test_slotted():
obj = MyObj(1, 2)
assert obj.a == 1
assert obj.b == 2
assert obj.__slots__ == ("a", "b")
with pytest.raises(AttributeError):
obj.c = 3

obj2 = MyObj(1, 2)
assert obj == obj2
assert obj is not obj2

obj3 = MyObj(1, 3)
assert obj != obj3

assert pickle.loads(pickle.dumps(obj)) == obj


class MyFrozenObj(FrozenSlotted):
__slots__ = ("a", "b")

def __init__(self, a, b):
super().__init__(a=a, b=b)


def test_frozen_slotted():
obj = MyFrozenObj(1, 2)
assert obj.a == 1
assert obj.b == 2
assert obj.__slots__ == ("a", "b")
with pytest.raises(AttributeError):
obj.b = 3
with pytest.raises(AttributeError):
obj.c = 3

obj2 = MyFrozenObj(1, 2)
assert obj == obj2
assert obj is not obj2
assert hash(obj) == hash(obj2)

restored = pickle.loads(pickle.dumps(obj))
assert restored == obj
assert hash(restored) == hash(obj)

0 comments on commit 13cbce0

Please sign in to comment.