Skip to content

Commit

Permalink
Merge pull request #6073 from qutech/bugfix/multi_channel_instrument_…
Browse files Browse the repository at this point in the history
…parameter_setter

Accept sequences of values for setting MultiChannelInstrumentParameter
  • Loading branch information
jenshnielsen authored May 27, 2024
2 parents 4a1e3cd + 65069d2 commit a01c543
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 5 deletions.
1 change: 1 addition & 0 deletions docs/changes/newsfragments/6073.improved
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Accept sequences of values for setting `MultiChannelInstrumentParameter` s. Previously, the behavior was inconsistent since `param.set(param.get())` would error.
32 changes: 27 additions & 5 deletions src/qcodes/parameters/multi_channel_instrument_parameter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import logging
import sys
from typing import TYPE_CHECKING, Any, Generic, TypeVar

from .multi_parameter import MultiParameter
Expand All @@ -12,6 +14,7 @@
from .parameter_base import ParamRawDataType

InstrumentModuleType = TypeVar("InstrumentModuleType", bound="InstrumentModule")
_LOG = logging.getLogger(__name__)


class MultiChannelInstrumentParameter(MultiParameter, Generic[InstrumentModuleType]):
Expand Down Expand Up @@ -45,16 +48,35 @@ def get_raw(self) -> tuple[ParamRawDataType, ...]:
"""
return tuple(chan.parameters[self._param_name].get() for chan in self._channels)

def set_raw(self, value: ParamRawDataType) -> None:
def set_raw(self, value: ParamRawDataType | Sequence[ParamRawDataType]) -> None:
"""
Set all parameters to this value.
Set all parameters to this/these value(s).
Args:
value: The value to set to. The type is given by the
value: The value(s) to set to. The type is given by the
underlying parameter.
"""
for chan in self._channels:
getattr(chan, self._param_name).set(value)
try:
for chan in self._channels:
getattr(chan, self._param_name).set(value)
except Exception as err:
try:
# Catch wrong length of value before any setting is done
value_list = list(value)
if len(value_list) != len(self._channels):
raise ValueError
for chan, val in zip(self._channels, value_list):
getattr(chan, self._param_name).set(val)
except (TypeError, ValueError):
note = (
"Value should either be valid for a single parameter of the channel list "
"or a sequence of valid values of the same length as the list."
)
if sys.version_info >= (3, 11):
err.add_note(note)
else:
_LOG.error(note)
raise err from None

@property
def full_names(self) -> tuple[str, ...]:
Expand Down
53 changes: 53 additions & 0 deletions tests/parameter/test_multi_channel_instrument_parameter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import sys
from typing import TYPE_CHECKING

import pytest

from qcodes.instrument_drivers.mock_instruments import DummyChannelInstrument

if TYPE_CHECKING:
from collections.abc import Generator


@pytest.fixture
def dummy_channel_instrument() -> "Generator[DummyChannelInstrument, None, None]":
instrument = DummyChannelInstrument(name="testdummy")
try:
yield instrument
finally:
instrument.close()


@pytest.fixture
def assert_raises_match() -> str:
if sys.version_info >= (3, 11):
return "Value should either be valid"
else:
return ""


def test_set_multi_channel_instrument_parameter(
dummy_channel_instrument: DummyChannelInstrument, assert_raises_match: str
):
"""Tests :class:`MultiChannelInstrumentParameter` set method."""
for name, param in dummy_channel_instrument.channels[0].parameters.items():
if not param.settable:
continue

channel_parameter = getattr(dummy_channel_instrument.channels, name)

getval = channel_parameter.get()

channel_parameter.set(getval[0])

# Assert channel parameter setters accept what the getter returns (PR #6073)
channel_parameter.set(getval)

with pytest.raises(TypeError, match=assert_raises_match):
channel_parameter.set(getval[:-1])

with pytest.raises(TypeError, match=assert_raises_match):
channel_parameter.set(getval + (getval[-1],))

with pytest.raises(TypeError):
channel_parameter.set(object())

0 comments on commit a01c543

Please sign in to comment.