From 9d46c9d0d2bdc0fab23dc6c5462c44659f486eb9 Mon Sep 17 00:00:00 2001 From: "Jens H. Nielsen" Date: Thu, 5 Sep 2024 17:01:23 +0200 Subject: [PATCH 1/2] Extend doNd to support squashing not squashing output type --- src/qcodes/dataset/dond/do_nd.py | 52 +++++++++++++++++++++++++- src/qcodes/dataset/dond/do_nd_utils.py | 12 +++--- tests/dataset/dond/test_doNd.py | 35 +++++++++++++++++ 3 files changed, 91 insertions(+), 8 deletions(-) diff --git a/src/qcodes/dataset/dond/do_nd.py b/src/qcodes/dataset/dond/do_nd.py index cdf20e9a53e..a0342fe1e7d 100644 --- a/src/qcodes/dataset/dond/do_nd.py +++ b/src/qcodes/dataset/dond/do_nd.py @@ -6,7 +6,7 @@ from collections.abc import Callable, Mapping, Sequence from contextlib import ExitStack from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Literal, cast, overload import numpy as np from opentelemetry import trace @@ -567,6 +567,46 @@ def parameters(self) -> tuple[ParameterBase, ...]: return self._parameters +@overload +def dond( + *params: AbstractSweep | TogetherSweep | ParamMeasT | Sequence[ParamMeasT], + write_period: float | None = None, + measurement_name: str | Sequence[str] = "", + exp: Experiment | Sequence[Experiment] | None = None, + enter_actions: ActionsT = (), + exit_actions: ActionsT = (), + do_plot: bool | None = None, + show_progress: bool | None = None, + use_threads: bool | None = None, + additional_setpoints: Sequence[ParameterBase] = tuple(), + log_info: str | None = None, + break_condition: BreakConditionT | None = None, + dataset_dependencies: Mapping[str, Sequence[ParamMeasT]] | None = None, + in_memory_cache: bool | None = None, + squeeze: Literal[False] = False, +) -> MultiAxesTupleListWithDataSet: ... + + +@overload +def dond( + *params: AbstractSweep | TogetherSweep | ParamMeasT | Sequence[ParamMeasT], + write_period: float | None = None, + measurement_name: str | Sequence[str] = "", + exp: Experiment | Sequence[Experiment] | None = None, + enter_actions: ActionsT = (), + exit_actions: ActionsT = (), + do_plot: bool | None = None, + show_progress: bool | None = None, + use_threads: bool | None = None, + additional_setpoints: Sequence[ParameterBase] = tuple(), + log_info: str | None = None, + break_condition: BreakConditionT | None = None, + dataset_dependencies: Mapping[str, Sequence[ParamMeasT]] | None = None, + in_memory_cache: bool | None = None, + squeeze: Literal[True] = True, +) -> AxesTupleListWithDataSet | MultiAxesTupleListWithDataSet: ... + + @TRACER.start_as_current_span("qcodes.dataset.dond") def dond( *params: AbstractSweep | TogetherSweep | ParamMeasT | Sequence[ParamMeasT], @@ -583,6 +623,7 @@ def dond( break_condition: BreakConditionT | None = None, dataset_dependencies: Mapping[str, Sequence[ParamMeasT]] | None = None, in_memory_cache: bool | None = None, + squeeze: bool = True, ) -> AxesTupleListWithDataSet | MultiAxesTupleListWithDataSet: """ Perform n-dimentional scan from slowest (first) to the fastest (last), to @@ -653,6 +694,13 @@ def dond( plotting and exporting. Useful to disable if the data is very large in order to save on memory consumption. If ``None``, the value for this will be read from ``qcodesrc.json`` config file. + squeeze: If True, will return a tuple of QCoDeS DataSet, Matplotlib axis, + Matplotlib colorbar if only one group of measurements was performed + and a tuple of tuples of these if more than one group of measurements + was performed. If False, will always return a tuple where the first + member is a tuple of QCoDeS DataSet(s) and the second member is a tuple + of Matplotlib axis(es) and the third member is a tuple of Matplotlib + colorbar(s). Returns: A tuple of QCoDeS DataSet, Matplotlib axis, Matplotlib colorbar. If @@ -764,7 +812,7 @@ def dond( plots_axes.append(plot_axis) plots_colorbar.append(plot_color) - if len(measurements.groups) == 1: + if len(measurements.groups) == 1 and squeeze is True: return datasets[0], plots_axes[0], plots_colorbar[0] else: return tuple(datasets), tuple(plots_axes), tuple(plots_colorbar) diff --git a/src/qcodes/dataset/dond/do_nd_utils.py b/src/qcodes/dataset/dond/do_nd_utils.py index 29bed026fdb..2291a400f05 100644 --- a/src/qcodes/dataset/dond/do_nd_utils.py +++ b/src/qcodes/dataset/dond/do_nd_utils.py @@ -3,7 +3,7 @@ import logging from collections.abc import Callable, Iterator, Sequence from contextlib import contextmanager -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING if TYPE_CHECKING: import matplotlib.axes @@ -27,17 +27,17 @@ AxesTuple = tuple["matplotlib.axes.Axes", "matplotlib.colorbar.Colorbar"] AxesTupleList = tuple[ - list["matplotlib.axes.Axes"], list[Optional["matplotlib.colorbar.Colorbar"]] + list["matplotlib.axes.Axes"], list["matplotlib.colorbar.Colorbar | None"] ] AxesTupleListWithDataSet = tuple[ DataSetProtocol, - tuple[Optional["matplotlib.axes.Axes"], ...], - tuple[Optional["matplotlib.colorbar.Colorbar"], ...], + tuple["matplotlib.axes.Axes | None", ...], + tuple["matplotlib.colorbar.Colorbar | None", ...], ] MultiAxesTupleListWithDataSet = tuple[ tuple[DataSetProtocol, ...], - tuple[tuple[Optional["matplotlib.axes.Axes"], ...], ...], - tuple[tuple[Optional["matplotlib.colorbar.Colorbar"], ...], ...], + tuple[tuple["matplotlib.axes.Axes | None", ...], ...], + tuple[tuple["matplotlib.colorbar.Colorbar | None", ...], ...], ] diff --git a/tests/dataset/dond/test_doNd.py b/tests/dataset/dond/test_doNd.py index ebdfe129ee1..1004edfddb1 100644 --- a/tests/dataset/dond/test_doNd.py +++ b/tests/dataset/dond/test_doNd.py @@ -8,11 +8,13 @@ import hypothesis.strategies as hst import matplotlib import matplotlib.axes +import matplotlib.colorbar import matplotlib.pyplot as plt import numpy as np import pytest from hypothesis import HealthCheck, given, settings from pytest import FixtureRequest, LogCaptureFixture +from typing_extensions import assert_type import qcodes as qc from qcodes import config, validators @@ -1817,3 +1819,36 @@ def test_dond_get_after_set_stores_get_value(_param_set, _param_set_2, _param) - assert a.set_count == n_points assert b.get_count == n_points assert b.set_count == 0 + + +@pytest.mark.usefixtures("plot_close", "experiment") +def test_dond_return_type(_param_set, _param) -> None: + n_points = 11 + + # test that with squeeze=False we get MultiAxesTupleListWithDataSet as the return type + dss, axs, cbs = dond( + LinSweep(_param_set, -10, -20, n_points), _param, squeeze=False + ) + + assert isinstance(dss, tuple) + assert_type(dss, tuple[DataSetProtocol, ...]) + assert len(dss) == 1 + assert isinstance(dss[0], DataSetProtocol) + + assert isinstance(axs, tuple) + assert_type( + axs, + tuple[tuple["matplotlib.axes.Axes | None", ...], ...], + ) + assert len(axs) == 1 + assert len(axs[0]) == 1 + assert axs[0][0] is None + + assert isinstance(cbs, tuple) + assert_type( + cbs, + tuple[tuple["matplotlib.colorbar.Colorbar | None", ...], ...], + ) + assert len(cbs) == 1 + assert len(cbs[0]) == 1 + assert cbs[0][0] is None From c38330df2756d3a44d1a5d637d9f017ab38c1032 Mon Sep 17 00:00:00 2001 From: "Jens H. Nielsen" Date: Fri, 6 Sep 2024 10:22:52 +0200 Subject: [PATCH 2/2] Add changes for 6422 --- docs/changes/newsfragments/6422.improved | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 docs/changes/newsfragments/6422.improved diff --git a/docs/changes/newsfragments/6422.improved b/docs/changes/newsfragments/6422.improved new file mode 100644 index 00000000000..dbff954f311 --- /dev/null +++ b/docs/changes/newsfragments/6422.improved @@ -0,0 +1,3 @@ +`dond` now takes an optional `squeeze` flag as input. Inspired by Matplotlib's `plt.subplots` argument +of the same name, this allows the user to always get the same type returned from the function if set to False. +This makes it easier to write type checked code that uses `dond` as a function.