diff --git a/src/orion/core/worker/transformer.py b/src/orion/core/worker/transformer.py index f2b91b86c..22a293890 100644 --- a/src/orion/core/worker/transformer.py +++ b/src/orion/core/worker/transformer.py @@ -270,6 +270,10 @@ def __init__(self, categories): self._map = numpy.vectorize(lambda x: map_dict[x], otypes='i') self._imap = numpy.vectorize(lambda x: categories[x], otypes=[numpy.object]) + def __deepcopy__(self, memo): + """Make a deepcopy""" + return type(self)(self.categories) + def transform(self, point): """Return integers corresponding uniquely to the categories in `point`. diff --git a/tests/unittests/core/test_transformer.py b/tests/unittests/core/test_transformer.py index 7a7485bee..dc147a3fc 100644 --- a/tests/unittests/core/test_transformer.py +++ b/tests/unittests/core/test_transformer.py @@ -17,6 +17,12 @@ class TestIdentity(object): """Test subclasses of `Identity` transformation.""" + def test_deepcopy(self): + """Verify that the transformation object can be copied""" + t = Identity() + t.transform([2]) + copy.deepcopy(t) + def test_domain_and_target_type(self): """Check if attribute-like `domain_type` and `target_type` do what's expected. @@ -53,6 +59,12 @@ def test_repr_format(self): class TestReverse(object): """Test subclasses of `Reverse` transformation.""" + def test_deepcopy(self): + """Verify that the transformation object can be copied""" + t = Reverse(Quantize()) + t.transform([2]) + copy.deepcopy(t) + def test_domain_and_target_type(self): """Check if attribute-like `domain_type` and `target_type` do what's expected. @@ -94,6 +106,12 @@ def test_repr_format(self): class TestCompose(object): """Test subclasses of `Compose` transformation.""" + def test_deepcopy(self): + """Verify that the transformation object can be copied""" + t = Compose([Enumerate([2, 'asfa', 'ipsi']), OneHotEncode(3)], 'categorical') + t.transform([2]) + copy.deepcopy(t) + def test_domain_and_target_type(self): """Check if attribute-like `domain_type` and `target_type` do what's expected. @@ -189,6 +207,12 @@ def test_repr_format(self): class TestQuantize(object): """Test subclasses of `Quantize` transformation.""" + def test_deepcopy(self): + """Verify that the transformation object can be copied""" + t = Quantize() + t.transform([2]) + copy.deepcopy(t) + def test_domain_and_target_type(self): """Check if attribute-like `domain_type` and `target_type` do what's expected. @@ -225,6 +249,13 @@ def test_repr_format(self): class TestEnumerate(object): """Test subclasses of `Enumerate` transformation.""" + def test_deepcopy(self): + """Verify that the transformation object can be copied""" + t = Enumerate([2, 'asfa', 'ipsi']) + # Copy won't fail if vectorized function is not called at least once. + t.transform([2]) + copy.deepcopy(t) + def test_domain_and_target_type(self): """Check if attribute-like `domain_type` and `target_type` do what's expected. @@ -283,6 +314,12 @@ def test_repr_format(self): class TestOneHotEncode(object): """Test subclasses of `OneHotEncode` transformation.""" + def test_deepcopy(self): + """Verify that the transformation object can be copied""" + t = OneHotEncode(3) + t.transform([2]) + copy.deepcopy(t) + def test_domain_and_target_type(self): """Check if attribute-like `domain_type` and `target_type` do what's expected.