Skip to content

Commit

Permalink
Merge pull request #1445 from jenshnielsen/db_metadata
Browse files Browse the repository at this point in the history
Ensure that partial run creation is completly rolled back
  • Loading branch information
jenshnielsen authored Jan 15, 2019
2 parents 3b5e79d + 0c2183a commit 74527ae
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 17 deletions.
19 changes: 10 additions & 9 deletions qcodes/dataset/sqlite_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2290,15 +2290,16 @@ def create_run(conn: ConnectionPlus, exp_id: int, name: str,
- formatted_name: the name of the newly created table
"""

run_counter, formatted_name, run_id = _insert_run(conn,
exp_id,
name,
guid,
parameters)
if metadata:
add_meta_data(conn, run_id, metadata)
_update_experiment_run_counter(conn, exp_id, run_counter)
_create_run_table(conn, formatted_name, parameters, values)
with atomic(conn):
run_counter, formatted_name, run_id = _insert_run(conn,
exp_id,
name,
guid,
parameters)
if metadata:
add_meta_data(conn, run_id, metadata)
_update_experiment_run_counter(conn, exp_id, run_counter)
_create_run_table(conn, formatted_name, parameters, values)

return run_counter, run_id, formatted_name

Expand Down
84 changes: 76 additions & 8 deletions qcodes/tests/dataset/test_sqlite_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
from sqlite3 import OperationalError
import tempfile
import os
from contextlib import contextmanager

import pytest
import hypothesis.strategies as hst
from hypothesis import given
import unicodedata
import numpy as np
from unittest.mock import patch

from qcodes.dataset.descriptions import RunDescriber
from qcodes.dataset.dependencies import InterDependencies
Expand All @@ -31,8 +33,18 @@
_unicode_categories = ('Lu', 'Ll', 'Lt', 'Lm', 'Lo', 'Nd', 'Pc', 'Pd', 'Zs')


def test_path_to_dbfile():
@contextmanager
def shadow_conn(path_to_db: str):
"""
Simple context manager to create a connection for testing and
close it on exit
"""
conn = mut.connect(path_to_db)
yield conn
conn.close()


def test_path_to_dbfile():
with tempfile.TemporaryDirectory() as tempdir:
tempdb = os.path.join(tempdir, 'database.db')
conn = mut.connect(tempdb)
Expand Down Expand Up @@ -100,7 +112,6 @@ def test__validate_table_raises(table_name):


def test_get_dependents(experiment):

x = ParamSpec('x', 'numeric')
t = ParamSpec('t', 'numeric')
y = ParamSpec('y', 'numeric', depends_on=['x', 't'])
Expand Down Expand Up @@ -142,7 +153,8 @@ def test_get_dependents(experiment):

def test_column_in_table(dataset):
assert mut.is_column_in_table(dataset.conn, "runs", "run_id")
assert not mut.is_column_in_table(dataset.conn, "runs", "non-existing-column")
assert not mut.is_column_in_table(dataset.conn, "runs",
"non-existing-column")


def test_run_exist(dataset):
Expand All @@ -168,7 +180,6 @@ def test_get_last_experiment_no_experiments(empty_temp_db):


def test_update_runs_description(dataset):

invalid_descs = ['{}', 'description']

for idesc in invalid_descs:
Expand Down Expand Up @@ -212,16 +223,17 @@ def test_get_parameter_data(scalar_dataset):
expected_names['param_3'] = ['param_0', 'param_1', 'param_2',
'param_3']
expected_shapes = {}
expected_shapes['param_3'] = [(10**3, )]*4
expected_shapes['param_3'] = [(10 ** 3,)] * 4

expected_values = {}
expected_values['param_3'] = [np.arange(10000*a, 10000*a+1000)
expected_values['param_3'] = [np.arange(10000 * a, 10000 * a + 1000)
for a in range(4)]
verify_data_dict(data, input_names, expected_names, expected_shapes,
expected_values)


def test_get_parameter_data_independent_parameters(standalone_parameters_dataset):
def test_get_parameter_data_independent_parameters(
standalone_parameters_dataset):
ds = standalone_parameters_dataset
params = mut.get_non_dependencies(ds.conn,
ds.run_id)
Expand All @@ -240,7 +252,7 @@ def test_get_parameter_data_independent_parameters(standalone_parameters_dataset
expected_shapes = {}
expected_shapes['param_1'] = [(10 ** 3,)]
expected_shapes['param_2'] = [(10 ** 3,)]
expected_shapes['param_3'] = [(10**3, )]*2
expected_shapes['param_3'] = [(10 ** 3,)] * 2

expected_values = {}
expected_values['param_1'] = [np.arange(10000, 10000 + 1000)]
Expand Down Expand Up @@ -270,3 +282,59 @@ def test_is_run_id_in_db(empty_temp_db):
acquired_dict = mut.is_run_id_in_database(conn, *try_ids)

assert expected_dict == acquired_dict


def test_atomic_creation(experiment):
""""
Test that dataset creation is atomic. Test for
https://github.com/QCoDeS/Qcodes/issues/1444
"""

def just_throw(*args):
raise RuntimeError("This breaks adding metadata")

# first we patch add_meta_data to throw an exception
# if create_data is not atomic this would create a partial
# run in the db. Causing the next create_run to fail
with patch('qcodes.dataset.sqlite_base.add_meta_data', new=just_throw):
x = ParamSpec('x', 'numeric')
t = ParamSpec('t', 'numeric')
y = ParamSpec('y', 'numeric', depends_on=['x', 't'])
with pytest.raises(RuntimeError,
match="Rolling back due to unhandled exception")as e:
mut.create_run(experiment.conn,
experiment.exp_id,
name='testrun',
guid=generate_guid(),
parameters=[x, t,
y],
metadata={'a': 1})
assert error_caused_by(e, "This breaks adding metadata")
# since we are starting from an empty database and the above transaction
# should be rolled back there should be no runs in the run table
runs = mut.transaction(experiment.conn,
'SELECT run_id FROM runs').fetchall()
assert len(runs) == 0
with shadow_conn(experiment.path_to_db) as new_conn:
runs = mut.transaction(new_conn,
'SELECT run_id FROM runs').fetchall()
assert len(runs) == 0

# if the above was not correctly rolled back we
# expect the next creation of a run to fail
mut.create_run(experiment.conn,
experiment.exp_id,
name='testrun',
guid=generate_guid(),
parameters=[x, t,
y],
metadata={'a': 1})

runs = mut.transaction(experiment.conn,
'SELECT run_id FROM runs').fetchall()
assert len(runs) == 1

with shadow_conn(experiment.path_to_db) as new_conn:
runs = mut.transaction(new_conn,
'SELECT run_id FROM runs').fetchall()
assert len(runs) == 1

0 comments on commit 74527ae

Please sign in to comment.