diff --git a/src/orion/core/worker/transformer.py b/src/orion/core/worker/transformer.py index 75fa7fd48..991bc6188 100644 --- a/src/orion/core/worker/transformer.py +++ b/src/orion/core/worker/transformer.py @@ -183,6 +183,13 @@ def reverse(self, transformed_point): transformed_point = self.apply.reverse(transformed_point) return self.composition.reverse(transformed_point) + def interval(self, alpha=1.0): + """Return interval of composed transformation.""" + if hasattr(self.apply, 'interval'): + return self.apply.interval(alpha) + + return None + def infer_target_shape(self, shape): """Return the shape of the dimension after transformation.""" shape = self.composition.infer_target_shape(shape) @@ -370,6 +377,14 @@ def reverse(self, transformed_point): assert point_.shape[-1] == self.num_cats return point_.argmax(axis=-1) + # pylint:disable=unused-argument + def interval(self, alpha=1.0): + """Return the interval for the one-hot encoding in proper shape.""" + low = numpy.zeros(self.num_cats) + high = numpy.ones(self.num_cats) + + return low, high + def infer_target_shape(self, shape): """Infer that transformed points will have one more tensor dimension, if the number of supported integers to transform is larger than 2. @@ -411,7 +426,11 @@ def sample(self, n_samples=1, seed=None): def interval(self, alpha=1.0): """Map the interval bounds to the transformed ones.""" - if self.original_dimension.prior_name == 'choices': + if hasattr(self.transformer, 'interval'): + interval = self.transformer.interval() + if interval: + return interval + elif self.original_dimension.prior_name == 'choices': return self.original_dimension.categories low, high = self.original_dimension.interval(alpha) diff --git a/tests/unittests/core/test_transformer.py b/tests/unittests/core/test_transformer.py index 0bb2b6955..c00d1e86f 100644 --- a/tests/unittests/core/test_transformer.py +++ b/tests/unittests/core/test_transformer.py @@ -430,6 +430,13 @@ def test_reverse(self): assert numpy.all(t.reverse([[0.5, 0], [1.0, 55]]) == numpy.array([[0, 0], [0, 0]], dtype=int)) + def test_interval(self): + """Test that the onehot interval has the proper dimensions""" + t = OneHotEncode(3) + low, high = t.interval() + assert (low == numpy.zeros(3)).all() + assert (high == numpy.ones(3)).all() + def test_infer_target_shape(self): """Check if it infers the shape of a transformed `Dimension`.""" t = OneHotEncode(3) @@ -531,7 +538,9 @@ def test_interval(self, tdim): def test_interval_from_categorical(self, tdim2): """Check how we should treat interval when original dimension is categorical.""" - assert tdim2.interval() == ('asdfa', '2', '3', '4') + low, high = tdim2.interval() + assert (low == numpy.zeros(4)).all() + assert (high == numpy.ones(4)).all() def test_contains(self, tdim): """Check method `__contains__`."""