Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend doNd to support squashing not squashing output type #6422

Merged
merged 2 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/changes/newsfragments/6422.improved
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
`dond` now takes an optional `squeeze` flag as input. Inspired by Matplotlib's `plt.subplots` argument
of the same name, this allows the user to always get the same type returned from the function if set to False.
This makes it easier to write type checked code that uses `dond` as a function.
52 changes: 50 additions & 2 deletions src/qcodes/dataset/dond/do_nd.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections.abc import Callable, Mapping, Sequence
from contextlib import ExitStack
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, Literal, cast, overload

import numpy as np
from opentelemetry import trace
Expand Down Expand Up @@ -567,6 +567,46 @@ def parameters(self) -> tuple[ParameterBase, ...]:
return self._parameters


@overload
def dond(
*params: AbstractSweep | TogetherSweep | ParamMeasT | Sequence[ParamMeasT],
write_period: float | None = None,
measurement_name: str | Sequence[str] = "",
exp: Experiment | Sequence[Experiment] | None = None,
enter_actions: ActionsT = (),
exit_actions: ActionsT = (),
do_plot: bool | None = None,
show_progress: bool | None = None,
use_threads: bool | None = None,
additional_setpoints: Sequence[ParameterBase] = tuple(),
log_info: str | None = None,
break_condition: BreakConditionT | None = None,
dataset_dependencies: Mapping[str, Sequence[ParamMeasT]] | None = None,
in_memory_cache: bool | None = None,
squeeze: Literal[False] = False,
) -> MultiAxesTupleListWithDataSet: ...


@overload
def dond(
*params: AbstractSweep | TogetherSweep | ParamMeasT | Sequence[ParamMeasT],
write_period: float | None = None,
measurement_name: str | Sequence[str] = "",
exp: Experiment | Sequence[Experiment] | None = None,
enter_actions: ActionsT = (),
exit_actions: ActionsT = (),
do_plot: bool | None = None,
show_progress: bool | None = None,
use_threads: bool | None = None,
additional_setpoints: Sequence[ParameterBase] = tuple(),
log_info: str | None = None,
break_condition: BreakConditionT | None = None,
dataset_dependencies: Mapping[str, Sequence[ParamMeasT]] | None = None,
in_memory_cache: bool | None = None,
squeeze: Literal[True] = True,
) -> AxesTupleListWithDataSet | MultiAxesTupleListWithDataSet: ...


@TRACER.start_as_current_span("qcodes.dataset.dond")
def dond(
*params: AbstractSweep | TogetherSweep | ParamMeasT | Sequence[ParamMeasT],
Expand All @@ -583,6 +623,7 @@ def dond(
break_condition: BreakConditionT | None = None,
dataset_dependencies: Mapping[str, Sequence[ParamMeasT]] | None = None,
in_memory_cache: bool | None = None,
squeeze: bool = True,
) -> AxesTupleListWithDataSet | MultiAxesTupleListWithDataSet:
"""
Perform n-dimentional scan from slowest (first) to the fastest (last), to
Expand Down Expand Up @@ -653,6 +694,13 @@ def dond(
plotting and exporting. Useful to disable if the data is very large
in order to save on memory consumption.
If ``None``, the value for this will be read from ``qcodesrc.json`` config file.
squeeze: If True, will return a tuple of QCoDeS DataSet, Matplotlib axis,
Matplotlib colorbar if only one group of measurements was performed
and a tuple of tuples of these if more than one group of measurements
was performed. If False, will always return a tuple where the first
member is a tuple of QCoDeS DataSet(s) and the second member is a tuple
of Matplotlib axis(es) and the third member is a tuple of Matplotlib
colorbar(s).

Returns:
A tuple of QCoDeS DataSet, Matplotlib axis, Matplotlib colorbar. If
Expand Down Expand Up @@ -764,7 +812,7 @@ def dond(
plots_axes.append(plot_axis)
plots_colorbar.append(plot_color)

if len(measurements.groups) == 1:
if len(measurements.groups) == 1 and squeeze is True:
return datasets[0], plots_axes[0], plots_colorbar[0]
else:
return tuple(datasets), tuple(plots_axes), tuple(plots_colorbar)
Expand Down
12 changes: 6 additions & 6 deletions src/qcodes/dataset/dond/do_nd_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from collections.abc import Callable, Iterator, Sequence
from contextlib import contextmanager
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING

if TYPE_CHECKING:
import matplotlib.axes
Expand All @@ -27,17 +27,17 @@

AxesTuple = tuple["matplotlib.axes.Axes", "matplotlib.colorbar.Colorbar"]
AxesTupleList = tuple[
list["matplotlib.axes.Axes"], list[Optional["matplotlib.colorbar.Colorbar"]]
list["matplotlib.axes.Axes"], list["matplotlib.colorbar.Colorbar | None"]
]
AxesTupleListWithDataSet = tuple[
DataSetProtocol,
tuple[Optional["matplotlib.axes.Axes"], ...],
tuple[Optional["matplotlib.colorbar.Colorbar"], ...],
tuple["matplotlib.axes.Axes | None", ...],
tuple["matplotlib.colorbar.Colorbar | None", ...],
]
MultiAxesTupleListWithDataSet = tuple[
tuple[DataSetProtocol, ...],
tuple[tuple[Optional["matplotlib.axes.Axes"], ...], ...],
tuple[tuple[Optional["matplotlib.colorbar.Colorbar"], ...], ...],
tuple[tuple["matplotlib.axes.Axes | None", ...], ...],
tuple[tuple["matplotlib.colorbar.Colorbar | None", ...], ...],
]


Expand Down
35 changes: 35 additions & 0 deletions tests/dataset/dond/test_doNd.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
import hypothesis.strategies as hst
import matplotlib
import matplotlib.axes
import matplotlib.colorbar
import matplotlib.pyplot as plt
import numpy as np
import pytest
from hypothesis import HealthCheck, given, settings
from pytest import FixtureRequest, LogCaptureFixture
from typing_extensions import assert_type

import qcodes as qc
from qcodes import config, validators
Expand Down Expand Up @@ -1817,3 +1819,36 @@ def test_dond_get_after_set_stores_get_value(_param_set, _param_set_2, _param) -
assert a.set_count == n_points
assert b.get_count == n_points
assert b.set_count == 0


@pytest.mark.usefixtures("plot_close", "experiment")
def test_dond_return_type(_param_set, _param) -> None:
n_points = 11

# test that with squeeze=False we get MultiAxesTupleListWithDataSet as the return type
dss, axs, cbs = dond(
LinSweep(_param_set, -10, -20, n_points), _param, squeeze=False
)

assert isinstance(dss, tuple)
assert_type(dss, tuple[DataSetProtocol, ...])
assert len(dss) == 1
assert isinstance(dss[0], DataSetProtocol)

assert isinstance(axs, tuple)
assert_type(
axs,
tuple[tuple["matplotlib.axes.Axes | None", ...], ...],
)
assert len(axs) == 1
assert len(axs[0]) == 1
assert axs[0][0] is None

assert isinstance(cbs, tuple)
assert_type(
cbs,
tuple[tuple["matplotlib.colorbar.Colorbar | None", ...], ...],
)
assert len(cbs) == 1
assert len(cbs[0]) == 1
assert cbs[0][0] is None
Loading