diff --git a/qcodes/dataset/sqlite_base.py b/qcodes/dataset/sqlite_base.py index 9c18f02f236..fee22744eac 100644 --- a/qcodes/dataset/sqlite_base.py +++ b/qcodes/dataset/sqlite_base.py @@ -2290,15 +2290,16 @@ def create_run(conn: ConnectionPlus, exp_id: int, name: str, - formatted_name: the name of the newly created table """ - run_counter, formatted_name, run_id = _insert_run(conn, - exp_id, - name, - guid, - parameters) - if metadata: - add_meta_data(conn, run_id, metadata) - _update_experiment_run_counter(conn, exp_id, run_counter) - _create_run_table(conn, formatted_name, parameters, values) + with atomic(conn): + run_counter, formatted_name, run_id = _insert_run(conn, + exp_id, + name, + guid, + parameters) + if metadata: + add_meta_data(conn, run_id, metadata) + _update_experiment_run_counter(conn, exp_id, run_counter) + _create_run_table(conn, formatted_name, parameters, values) return run_counter, run_id, formatted_name diff --git a/qcodes/tests/dataset/test_sqlite_base.py b/qcodes/tests/dataset/test_sqlite_base.py index a5e722f8fde..256072160b8 100644 --- a/qcodes/tests/dataset/test_sqlite_base.py +++ b/qcodes/tests/dataset/test_sqlite_base.py @@ -4,12 +4,14 @@ from sqlite3 import OperationalError import tempfile import os +from contextlib import contextmanager import pytest import hypothesis.strategies as hst from hypothesis import given import unicodedata import numpy as np +from unittest.mock import patch from qcodes.dataset.descriptions import RunDescriber from qcodes.dataset.dependencies import InterDependencies @@ -31,8 +33,18 @@ _unicode_categories = ('Lu', 'Ll', 'Lt', 'Lm', 'Lo', 'Nd', 'Pc', 'Pd', 'Zs') -def test_path_to_dbfile(): +@contextmanager +def shadow_conn(path_to_db: str): + """ + Simple context manager to create a connection for testing and + close it on exit + """ + conn = mut.connect(path_to_db) + yield conn + conn.close() + +def test_path_to_dbfile(): with tempfile.TemporaryDirectory() as tempdir: tempdb = os.path.join(tempdir, 'database.db') conn = mut.connect(tempdb) @@ -100,7 +112,6 @@ def test__validate_table_raises(table_name): def test_get_dependents(experiment): - x = ParamSpec('x', 'numeric') t = ParamSpec('t', 'numeric') y = ParamSpec('y', 'numeric', depends_on=['x', 't']) @@ -142,7 +153,8 @@ def test_get_dependents(experiment): def test_column_in_table(dataset): assert mut.is_column_in_table(dataset.conn, "runs", "run_id") - assert not mut.is_column_in_table(dataset.conn, "runs", "non-existing-column") + assert not mut.is_column_in_table(dataset.conn, "runs", + "non-existing-column") def test_run_exist(dataset): @@ -168,7 +180,6 @@ def test_get_last_experiment_no_experiments(empty_temp_db): def test_update_runs_description(dataset): - invalid_descs = ['{}', 'description'] for idesc in invalid_descs: @@ -212,16 +223,17 @@ def test_get_parameter_data(scalar_dataset): expected_names['param_3'] = ['param_0', 'param_1', 'param_2', 'param_3'] expected_shapes = {} - expected_shapes['param_3'] = [(10**3, )]*4 + expected_shapes['param_3'] = [(10 ** 3,)] * 4 expected_values = {} - expected_values['param_3'] = [np.arange(10000*a, 10000*a+1000) + expected_values['param_3'] = [np.arange(10000 * a, 10000 * a + 1000) for a in range(4)] verify_data_dict(data, input_names, expected_names, expected_shapes, expected_values) -def test_get_parameter_data_independent_parameters(standalone_parameters_dataset): +def test_get_parameter_data_independent_parameters( + standalone_parameters_dataset): ds = standalone_parameters_dataset params = mut.get_non_dependencies(ds.conn, ds.run_id) @@ -240,7 +252,7 @@ def test_get_parameter_data_independent_parameters(standalone_parameters_dataset expected_shapes = {} expected_shapes['param_1'] = [(10 ** 3,)] expected_shapes['param_2'] = [(10 ** 3,)] - expected_shapes['param_3'] = [(10**3, )]*2 + expected_shapes['param_3'] = [(10 ** 3,)] * 2 expected_values = {} expected_values['param_1'] = [np.arange(10000, 10000 + 1000)] @@ -270,3 +282,59 @@ def test_is_run_id_in_db(empty_temp_db): acquired_dict = mut.is_run_id_in_database(conn, *try_ids) assert expected_dict == acquired_dict + + +def test_atomic_creation(experiment): + """" + Test that dataset creation is atomic. Test for + https://github.com/QCoDeS/Qcodes/issues/1444 + """ + + def just_throw(*args): + raise RuntimeError("This breaks adding metadata") + + # first we patch add_meta_data to throw an exception + # if create_data is not atomic this would create a partial + # run in the db. Causing the next create_run to fail + with patch('qcodes.dataset.sqlite_base.add_meta_data', new=just_throw): + x = ParamSpec('x', 'numeric') + t = ParamSpec('t', 'numeric') + y = ParamSpec('y', 'numeric', depends_on=['x', 't']) + with pytest.raises(RuntimeError, + match="Rolling back due to unhandled exception")as e: + mut.create_run(experiment.conn, + experiment.exp_id, + name='testrun', + guid=generate_guid(), + parameters=[x, t, + y], + metadata={'a': 1}) + assert error_caused_by(e, "This breaks adding metadata") + # since we are starting from an empty database and the above transaction + # should be rolled back there should be no runs in the run table + runs = mut.transaction(experiment.conn, + 'SELECT run_id FROM runs').fetchall() + assert len(runs) == 0 + with shadow_conn(experiment.path_to_db) as new_conn: + runs = mut.transaction(new_conn, + 'SELECT run_id FROM runs').fetchall() + assert len(runs) == 0 + + # if the above was not correctly rolled back we + # expect the next creation of a run to fail + mut.create_run(experiment.conn, + experiment.exp_id, + name='testrun', + guid=generate_guid(), + parameters=[x, t, + y], + metadata={'a': 1}) + + runs = mut.transaction(experiment.conn, + 'SELECT run_id FROM runs').fetchall() + assert len(runs) == 1 + + with shadow_conn(experiment.path_to_db) as new_conn: + runs = mut.transaction(new_conn, + 'SELECT run_id FROM runs').fetchall() + assert len(runs) == 1