Skip to content

Commit

Permalink
Add uid to storage.update_experiment
Browse files Browse the repository at this point in the history
Why:

`storage.fetch_experiments` returns documents, not `Experiment` objects.
This makes it complicated to just fetch and update an experiment given
that `update_experiment` expects an `Experiment` object.
  • Loading branch information
bouthilx committed Oct 9, 2019
1 parent 643f0f6 commit 1cecc30
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 5 deletions.
17 changes: 14 additions & 3 deletions src/orion/storage/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,16 @@ def create_experiment(self, config):
"""Insert a new experiment inside the database"""
raise NotImplementedError()

def update_experiment(self, experiment, where=None, **kwargs):
def update_experiment(self, experiment=None, uid=None, where=None, **kwargs):
"""Update a the fields of a given trials
Parameters
----------
experiment: Experiment
Experiment object to update
experiment: Experiment, optional
experiment object to retrieve from the database
uid: str, optional
experiment id used to retrieve the trial object
where: Optional[dict]
constraint experiment must respect
Expand All @@ -53,6 +56,14 @@ def update_experiment(self, experiment, where=None, **kwargs):
-------
returns true if the underlying storage was updated
Raises
------
UndefinedCall
if both experiment and uid are not set
AssertionError
if both experiment and uid are provided and they do not match
"""
raise NotImplementedError()

Expand Down
13 changes: 11 additions & 2 deletions src/orion/storage/legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,21 @@ def create_experiment(self, config):
"""See :func:`~orion.storage.BaseStorageProtocol.create_experiment`"""
return self._db.write('experiments', data=config, query=None)

def update_experiment(self, experiment, where=None, **kwargs):
def update_experiment(self, experiment=None, uid=None, where=None, **kwargs):
"""See :func:`~orion.storage.BaseStorageProtocol.update_experiment`"""
if experiment is not None and uid is not None:
assert experiment._id == uid

if uid is None:
if experiment is None:
raise MissingArguments('Either `experiment` or `uid` should be set')

uid = experiment._id

if where is None:
where = dict()

where['_id'] = experiment._id
where['_id'] = uid
return self._db.write('experiments', data=kwargs, query=where)

def fetch_experiments(self, query, selection=None):
Expand Down
26 changes: 26 additions & 0 deletions tests/unittests/storage/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,32 @@ def test_fetch_experiments(self, storage, name='0', user='a'):
experiments = storage.fetch_experiments({'name': '-1', 'metadata.user': user})
assert len(experiments) == 0

def test_update_experiment(self, monkeypatch, storage, name='0', user='a'):
"""Test fetch experiments"""
with OrionState(experiments=generate_experiments(), database=storage) as cfg:
storage = cfg.storage()

class _Dummy():
pass

experiment = cfg.experiments[0]
mocked_experiment = _Dummy()
mocked_experiment._id = experiment['_id']

storage.update_experiment(mocked_experiment, test=True)
assert storage.fetch_experiments({'_id': experiment['_id']})[0]['test']
assert 'test' not in storage.fetch_experiments({'_id': cfg.experiments[1]['_id']})[0]

storage.update_experiment(uid=experiment['_id'], test2=True)
assert storage.fetch_experiments({'_id': experiment['_id']})[0]['test2']
assert 'test2' not in storage.fetch_experiments({'_id': cfg.experiments[1]['_id']})[0]

with pytest.raises(MissingArguments):
storage.update_experiment()

with pytest.raises(AssertionError):
storage.update_experiment(experiment=mocked_experiment, uid='123')

def test_register_trial(self, storage):
"""Test register trial"""
with OrionState(experiments=[base_experiment], database=storage) as cfg:
Expand Down

0 comments on commit 1cecc30

Please sign in to comment.