Skip to content

Commit

Permalink
Fix deepcopy of Enumerate transform
Browse files Browse the repository at this point in the history
Why:

The vectorized function of Enumerate could not be `deepcopy`ed making it
unusable because of naive algorithm's that requires copying the original
algo.

How:

Overwrite `__deepcopy__` the rebuild the Enumerate object instead of
copying it.
  • Loading branch information
bouthilx committed Oct 5, 2019
1 parent 31b8e1b commit 9b9944f
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/orion/core/worker/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,10 @@ def __init__(self, categories):
self._map = numpy.vectorize(lambda x: map_dict[x], otypes='i')
self._imap = numpy.vectorize(lambda x: categories[x], otypes=[numpy.object])

def __deepcopy__(self, memo):
"""Make a deepcopy"""
return type(self)(self.categories)

def transform(self, point):
"""Return integers corresponding uniquely to the categories in `point`.
Expand Down
37 changes: 37 additions & 0 deletions tests/unittests/core/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
class TestIdentity(object):
"""Test subclasses of `Identity` transformation."""

def test_deepcopy(self):
"""Verify that the transformation object can be copied"""
t = Identity()
t.transform([2])
copy.deepcopy(t)

def test_domain_and_target_type(self):
"""Check if attribute-like `domain_type` and `target_type` do
what's expected.
Expand Down Expand Up @@ -53,6 +59,12 @@ def test_repr_format(self):
class TestReverse(object):
"""Test subclasses of `Reverse` transformation."""

def test_deepcopy(self):
"""Verify that the transformation object can be copied"""
t = Reverse(Quantize())
t.transform([2])
copy.deepcopy(t)

def test_domain_and_target_type(self):
"""Check if attribute-like `domain_type` and `target_type` do
what's expected.
Expand Down Expand Up @@ -94,6 +106,12 @@ def test_repr_format(self):
class TestCompose(object):
"""Test subclasses of `Compose` transformation."""

def test_deepcopy(self):
"""Verify that the transformation object can be copied"""
t = Compose([Enumerate([2, 'asfa', 'ipsi']), OneHotEncode(3)], 'categorical')
t.transform([2])
copy.deepcopy(t)

def test_domain_and_target_type(self):
"""Check if attribute-like `domain_type` and `target_type` do
what's expected.
Expand Down Expand Up @@ -189,6 +207,12 @@ def test_repr_format(self):
class TestQuantize(object):
"""Test subclasses of `Quantize` transformation."""

def test_deepcopy(self):
"""Verify that the transformation object can be copied"""
t = Quantize()
t.transform([2])
copy.deepcopy(t)

def test_domain_and_target_type(self):
"""Check if attribute-like `domain_type` and `target_type` do
what's expected.
Expand Down Expand Up @@ -225,6 +249,13 @@ def test_repr_format(self):
class TestEnumerate(object):
"""Test subclasses of `Enumerate` transformation."""

def test_deepcopy(self):
"""Verify that the transformation object can be copied"""
t = Enumerate([2, 'asfa', 'ipsi'])
# Copy won't fail if vectorized function is not called at least once.
t.transform([2])
copy.deepcopy(t)

def test_domain_and_target_type(self):
"""Check if attribute-like `domain_type` and `target_type` do
what's expected.
Expand Down Expand Up @@ -283,6 +314,12 @@ def test_repr_format(self):
class TestOneHotEncode(object):
"""Test subclasses of `OneHotEncode` transformation."""

def test_deepcopy(self):
"""Verify that the transformation object can be copied"""
t = OneHotEncode(3)
t.transform([2])
copy.deepcopy(t)

def test_domain_and_target_type(self):
"""Check if attribute-like `domain_type` and `target_type` do
what's expected.
Expand Down

0 comments on commit 9b9944f

Please sign in to comment.