Skip to content

Commit

Permalink
make trial import lazy to avoid cyclic dependencies (#316)
Browse files Browse the repository at this point in the history
  • Loading branch information
Delaunay authored and bouthilx committed Jun 23, 2020
1 parent 64fd318 commit cbda12b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
11 changes: 9 additions & 2 deletions src/orion/storage/legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from orion.core.io.convert import JSONConverter
from orion.core.io.database import Database, OutdatedDatabaseError
import orion.core.utils.backward as backward
from orion.core.worker.trial import Trial
from orion.storage.base import BaseStorageProtocol, FailedUpdate, MissingArguments

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -127,6 +126,8 @@ def fetch_trials(self, experiment=None, uid=None):

def _fetch_trials(self, query, selection=None):
"""See :func:`~orion.storage.BaseStorageProtocol.fetch_trials`"""
from orion.core.worker.trial import Trial

def sort_key(item):
submit_time = item.submit_time
if submit_time is None:
Expand Down Expand Up @@ -167,6 +168,8 @@ def retrieve_result(self, trial, results_file=None, **kwargs):
This does not update the database!
"""
from orion.core.worker.trial import Trial

if results_file is None:
return trial

Expand All @@ -183,6 +186,8 @@ def retrieve_result(self, trial, results_file=None, **kwargs):

def get_trial(self, trial=None, uid=None):
"""See :func:`~orion.storage.BaseStorageProtocol.get_trial`"""
from orion.core.worker.trial import Trial

if trial is not None and uid is not None:
assert trial._id == uid

Expand All @@ -198,7 +203,7 @@ def get_trial(self, trial=None, uid=None):

return Trial(**result[0])

def _update_trial(self, trial: Trial, where=None, **kwargs):
def _update_trial(self, trial, where=None, **kwargs):
"""See :func:`~orion.storage.BaseStorageProtocol.update_trial`"""
if where is None:
where = dict()
Expand Down Expand Up @@ -256,6 +261,8 @@ def fetch_pending_trials(self, experiment):

def reserve_trial(self, experiment):
"""See :func:`~orion.storage.BaseStorageProtocol.reserve_trial`"""
from orion.core.worker.trial import Trial

query = dict(
experiment=experiment._id,
status={'$in': ['interrupted', 'new', 'suspended']}
Expand Down
7 changes: 6 additions & 1 deletion src/orion/storage/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import warnings

from orion.core.io.database import DuplicateKeyError
from orion.core.worker.trial import Trial as OrionTrial
from orion.storage.base import BaseStorageProtocol, FailedUpdate, MissingArguments

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -167,6 +166,8 @@ def params(self):
@property
def _params(self):
"""See `~orion.core.worker.trial.Trial`"""
from orion.core.worker.trial import Trial as OrionTrial

if self.memory is not None:
return self.memory._params

Expand Down Expand Up @@ -221,6 +222,8 @@ def lie(self):
@property
def objective(self):
"""See `~orion.core.worker.trial.Trial`"""
from orion.core.worker.trial import Trial as OrionTrial

def result(val):
return OrionTrial.Result(name=self.objective_key, value=val, type='objective')

Expand Down Expand Up @@ -251,6 +254,8 @@ def result(val):
@property
def results(self):
"""See `~orion.core.worker.trial.Trial`"""
from orion.core.worker.trial import Trial as OrionTrial

self._results = []

for k, values in self.storage.metrics.items():
Expand Down

0 comments on commit cbda12b

Please sign in to comment.