Skip to content

Commit

Permalink
Merge pull request #267 from bouthilx/feature/asha_opt_out
Browse files Browse the repository at this point in the history
 Make ASHA opt out when last rungs are filled
  • Loading branch information
bouthilx authored Aug 27, 2019
2 parents a2a242f + a2c420e commit d206aa8
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 16 deletions.
21 changes: 18 additions & 3 deletions src/orion/algo/asha.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ def suggest(self, num=1):
logger.debug('Promoting')
return [candidate]

if all(bracket.is_done for bracket in self.brackets):
logger.debug('All brackets are filled.')
return None

for _attempt in range(100):
point = list(self.space.sample(1, seed=tuple(self.rng.randint(0, 1000000, size=3)))[0])
if self.get_id(point) not in self.trial_info:
Expand All @@ -136,6 +140,8 @@ def suggest(self, num=1):

sizes = numpy.array([len(b.rungs) for b in self.brackets])
probs = numpy.e**(sizes - sizes.max())
probs = numpy.array([prob * int(not bracket.is_done)
for prob, bracket in zip(probs, self.brackets)])
normalized = probs / probs.sum()
idx = self.rng.choice(len(self.brackets), p=normalized)

Expand Down Expand Up @@ -256,13 +262,19 @@ def get_candidate(self, rung_id):

@property
def is_done(self):
"""Return True, if reached the bracket reached its maximum resources."""
return len(self.rungs[-1][1])
"""Return True, if the penultimate rung is filled."""
return self.is_filled(len(self.rungs) - 2) or len(self.rungs[-1][1])

def is_filled(self, rung_id):
"""Return True, if the rung[rung_id] is filled."""
n_rungs = len(self.rungs)
n_trials = len(self.rungs[rung_id][1])
return n_trials >= (n_rungs - rung_id - 1) ** self.reduction_factor

def update_rungs(self):
"""Promote the first candidate that is found and return it
The rungs are iterated over is reversed order, so that high rungs
The rungs are iterated over in reversed order, so that high rungs
are prioritised for promotions. When a candidate is promoted, the loop is broken and
the method returns the promoted point.
Expand All @@ -273,6 +285,9 @@ def update_rungs(self):
Lookup for promotion in rung l + 1 contains trials of any status.
"""
if self.is_done and self.rungs[-1][1]:
return None

for rung_id in range(len(self.rungs) - 2, -1, -1):
candidate = self.get_candidate(rung_id)
if candidate:
Expand Down
26 changes: 13 additions & 13 deletions tests/unittests/algo/test_asha.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,9 +338,6 @@ def test_suggest_new(self, monkeypatch, asha, bracket, rung_0, rung_1, rung_2):
"""Test that a new point is sampled."""
asha.brackets = [bracket]
bracket.asha = asha
bracket.rungs[0] = rung_0
bracket.rungs[1] = rung_1
bracket.rungs[2] = rung_2

def sample(num=1, seed=None):
return [('fidelity', 0.5)]
Expand All @@ -356,14 +353,12 @@ def test_suggest_duplicates(self, monkeypatch, asha, bracket, rung_0, rung_1, ru
asha.brackets = [bracket]
bracket.asha = asha

# Fill rungs to force sampling
bracket.rungs[0] = rung_0
bracket.rungs[1] = rung_1
bracket.rungs[2] = rung_2

duplicate_point = ('fidelity', 0.0)
new_point = ('fidelity', 0.5)

duplicate_id = hashlib.md5(str([duplicate_point]).encode('utf-8')).hexdigest()
bracket.rungs[0] = (1, {duplicate_id: (0.0, duplicate_point)})

asha.trial_info[asha.get_id(duplicate_point)] = bracket

points = [duplicate_point, new_point]
Expand All @@ -381,11 +376,6 @@ def test_suggest_inf_duplicates(self, monkeypatch, asha, bracket, rung_0, rung_1
asha.brackets = [bracket]
bracket.asha = asha

# Fill rungs to force sampling
bracket.rungs[0] = rung_0
bracket.rungs[1] = rung_1
bracket.rungs[2] = rung_2

zhe_point = ('fidelity', 0.0)
asha.trial_info[asha.get_id(zhe_point)] = bracket

Expand All @@ -409,6 +399,16 @@ def test_suggest_promote(self, asha, bracket, rung_0):

assert points == [(3, 0.0)]

def test_suggest_opt_out(self, asha, bracket, rung_0, rung_1, rung_2):
"""Test that ASHA opts out when last rung is full."""
asha.brackets = [bracket]
bracket.asha = asha
bracket.rungs[2] = rung_2

points = asha.suggest()

assert points is None

def test_seed_rng(self, asha):
"""Test that algo is seeded properly"""
asha.seed_rng(1)
Expand Down

0 comments on commit d206aa8

Please sign in to comment.