Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make trial import lazy to avoid cyclic dependencies #316

Merged
merged 1 commit into from
Nov 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/orion/storage/legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from orion.core.io.convert import JSONConverter
from orion.core.io.database import Database, OutdatedDatabaseError
import orion.core.utils.backward as backward
from orion.core.worker.trial import Trial
from orion.storage.base import BaseStorageProtocol, FailedUpdate, MissingArguments

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -127,6 +126,8 @@ def fetch_trials(self, experiment=None, uid=None):

def _fetch_trials(self, query, selection=None):
"""See :func:`~orion.storage.BaseStorageProtocol.fetch_trials`"""
from orion.core.worker.trial import Trial

def sort_key(item):
submit_time = item.submit_time
if submit_time is None:
Expand Down Expand Up @@ -167,6 +168,8 @@ def retrieve_result(self, trial, results_file=None, **kwargs):
This does not update the database!

"""
from orion.core.worker.trial import Trial

if results_file is None:
return trial

Expand All @@ -183,6 +186,8 @@ def retrieve_result(self, trial, results_file=None, **kwargs):

def get_trial(self, trial=None, uid=None):
"""See :func:`~orion.storage.BaseStorageProtocol.get_trial`"""
from orion.core.worker.trial import Trial

if trial is not None and uid is not None:
assert trial._id == uid

Expand All @@ -198,7 +203,7 @@ def get_trial(self, trial=None, uid=None):

return Trial(**result[0])

def _update_trial(self, trial: Trial, where=None, **kwargs):
def _update_trial(self, trial, where=None, **kwargs):
"""See :func:`~orion.storage.BaseStorageProtocol.update_trial`"""
if where is None:
where = dict()
Expand Down Expand Up @@ -256,6 +261,8 @@ def fetch_pending_trials(self, experiment):

def reserve_trial(self, experiment):
"""See :func:`~orion.storage.BaseStorageProtocol.reserve_trial`"""
from orion.core.worker.trial import Trial

query = dict(
experiment=experiment._id,
status={'$in': ['interrupted', 'new', 'suspended']}
Expand Down
7 changes: 6 additions & 1 deletion src/orion/storage/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import warnings

from orion.core.io.database import DuplicateKeyError
from orion.core.worker.trial import Trial as OrionTrial
from orion.storage.base import BaseStorageProtocol, FailedUpdate, MissingArguments

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -167,6 +166,8 @@ def params(self):
@property
def _params(self):
"""See `~orion.core.worker.trial.Trial`"""
from orion.core.worker.trial import Trial as OrionTrial

if self.memory is not None:
return self.memory._params

Expand Down Expand Up @@ -221,6 +222,8 @@ def lie(self):
@property
def objective(self):
"""See `~orion.core.worker.trial.Trial`"""
from orion.core.worker.trial import Trial as OrionTrial

def result(val):
return OrionTrial.Result(name=self.objective_key, value=val, type='objective')

Expand Down Expand Up @@ -251,6 +254,8 @@ def result(val):
@property
def results(self):
"""See `~orion.core.worker.trial.Trial`"""
from orion.core.worker.trial import Trial as OrionTrial

self._results = []

for k, values in self.storage.metrics.items():
Expand Down