Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return onehot interval instead of categorical #447

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/orion/core/worker/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,13 @@ def reverse(self, transformed_point):
transformed_point = self.apply.reverse(transformed_point)
return self.composition.reverse(transformed_point)

def interval(self, alpha=1.0):
"""Return interval of composed transformation."""
if hasattr(self.apply, 'interval'):
return self.apply.interval(alpha)

return None

def infer_target_shape(self, shape):
"""Return the shape of the dimension after transformation."""
shape = self.composition.infer_target_shape(shape)
Expand Down Expand Up @@ -370,6 +377,14 @@ def reverse(self, transformed_point):
assert point_.shape[-1] == self.num_cats
return point_.argmax(axis=-1)

# pylint:disable=unused-argument
def interval(self, alpha=1.0):
"""Return the interval for the one-hot encoding in proper shape."""
low = numpy.zeros(self.num_cats)
high = numpy.ones(self.num_cats)

return low, high

def infer_target_shape(self, shape):
"""Infer that transformed points will have one more tensor dimension,
if the number of supported integers to transform is larger than 2.
Expand Down Expand Up @@ -411,7 +426,11 @@ def sample(self, n_samples=1, seed=None):

def interval(self, alpha=1.0):
"""Map the interval bounds to the transformed ones."""
if self.original_dimension.prior_name == 'choices':
if hasattr(self.transformer, 'interval'):
interval = self.transformer.interval()
if interval:
return interval
elif self.original_dimension.prior_name == 'choices':
return self.original_dimension.categories

low, high = self.original_dimension.interval(alpha)
Expand Down
11 changes: 10 additions & 1 deletion tests/unittests/core/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,13 @@ def test_reverse(self):
assert numpy.all(t.reverse([[0.5, 0], [1.0, 55]]) == numpy.array([[0, 0], [0, 0]],
dtype=int))

def test_interval(self):
"""Test that the onehot interval has the proper dimensions"""
t = OneHotEncode(3)
low, high = t.interval()
assert (low == numpy.zeros(3)).all()
assert (high == numpy.ones(3)).all()

def test_infer_target_shape(self):
"""Check if it infers the shape of a transformed `Dimension`."""
t = OneHotEncode(3)
Expand Down Expand Up @@ -531,7 +538,9 @@ def test_interval(self, tdim):

def test_interval_from_categorical(self, tdim2):
"""Check how we should treat interval when original dimension is categorical."""
assert tdim2.interval() == ('asdfa', '2', '3', '4')
low, high = tdim2.interval()
assert (low == numpy.zeros(4)).all()
assert (high == numpy.ones(4)).all()

def test_contains(self, tdim):
"""Check method `__contains__`."""
Expand Down