From 7ea196b2017ddd04be70413bb28ca7fa2065cbe6 Mon Sep 17 00:00:00 2001 From: Xavier Bouthillier Date: Wed, 26 Aug 2020 15:07:21 -0400 Subject: [PATCH] Return onehot interval instead of categorical Why: When the categorical dimension is transformed to one-hot, the interval returned should be in the space of the one-hot otherwise the algorithm requesting a real space cannot handle the categories. How: Add an interval method to OneHot Transformer. TransformerDimension will look for an interval method, otherwise use the default interval of original dimension. Note that CompositeTransformer push the interval() method recursively if necessary, so that a one-hot encoding that is part of a composite transformer can still provide its interval(). --- src/orion/core/worker/transformer.py | 21 ++++++++++++++++++++- tests/unittests/core/test_transformer.py | 11 ++++++++++- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/src/orion/core/worker/transformer.py b/src/orion/core/worker/transformer.py index 75fa7fd48..991bc6188 100644 --- a/src/orion/core/worker/transformer.py +++ b/src/orion/core/worker/transformer.py @@ -183,6 +183,13 @@ def reverse(self, transformed_point): transformed_point = self.apply.reverse(transformed_point) return self.composition.reverse(transformed_point) + def interval(self, alpha=1.0): + """Return interval of composed transformation.""" + if hasattr(self.apply, 'interval'): + return self.apply.interval(alpha) + + return None + def infer_target_shape(self, shape): """Return the shape of the dimension after transformation.""" shape = self.composition.infer_target_shape(shape) @@ -370,6 +377,14 @@ def reverse(self, transformed_point): assert point_.shape[-1] == self.num_cats return point_.argmax(axis=-1) + # pylint:disable=unused-argument + def interval(self, alpha=1.0): + """Return the interval for the one-hot encoding in proper shape.""" + low = numpy.zeros(self.num_cats) + high = numpy.ones(self.num_cats) + + return low, high + def infer_target_shape(self, shape): """Infer that transformed points will have one more tensor dimension, if the number of supported integers to transform is larger than 2. @@ -411,7 +426,11 @@ def sample(self, n_samples=1, seed=None): def interval(self, alpha=1.0): """Map the interval bounds to the transformed ones.""" - if self.original_dimension.prior_name == 'choices': + if hasattr(self.transformer, 'interval'): + interval = self.transformer.interval() + if interval: + return interval + elif self.original_dimension.prior_name == 'choices': return self.original_dimension.categories low, high = self.original_dimension.interval(alpha) diff --git a/tests/unittests/core/test_transformer.py b/tests/unittests/core/test_transformer.py index 0bb2b6955..c00d1e86f 100644 --- a/tests/unittests/core/test_transformer.py +++ b/tests/unittests/core/test_transformer.py @@ -430,6 +430,13 @@ def test_reverse(self): assert numpy.all(t.reverse([[0.5, 0], [1.0, 55]]) == numpy.array([[0, 0], [0, 0]], dtype=int)) + def test_interval(self): + """Test that the onehot interval has the proper dimensions""" + t = OneHotEncode(3) + low, high = t.interval() + assert (low == numpy.zeros(3)).all() + assert (high == numpy.ones(3)).all() + def test_infer_target_shape(self): """Check if it infers the shape of a transformed `Dimension`.""" t = OneHotEncode(3) @@ -531,7 +538,9 @@ def test_interval(self, tdim): def test_interval_from_categorical(self, tdim2): """Check how we should treat interval when original dimension is categorical.""" - assert tdim2.interval() == ('asdfa', '2', '3', '4') + low, high = tdim2.interval() + assert (low == numpy.zeros(4)).all() + assert (high == numpy.ones(4)).all() def test_contains(self, tdim): """Check method `__contains__`."""