From 6a3130508f475ced0b3220ed763486850b8b83f3 Mon Sep 17 00:00:00 2001 From: Samantha Ho Date: Mon, 22 Apr 2024 14:32:24 -0700 Subject: [PATCH 01/13] Add Infer methods --- src/qcodes/parameters/infer.py | 109 +++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 src/qcodes/parameters/infer.py diff --git a/src/qcodes/parameters/infer.py b/src/qcodes/parameters/infer.py new file mode 100644 index 00000000000..3fbcd6d5b02 --- /dev/null +++ b/src/qcodes/parameters/infer.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, ClassVar + +from qcodes.instrument import Instrument, InstrumentBase, InstrumentChannel +from qcodes.instrument.parameter import DelegateParameter, Parameter + +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + + +class InferError(AttributeError): ... + + +class InferAttrs: + """Holds a global set of attribute name that will be inferred""" + + _known_attrs: ClassVar[set[str]] = set() + + @classmethod + def add_attr(cls, attr: str) -> None: + cls._known_attrs.add(attr) + + @classmethod + def known_attrs(cls) -> tuple[str, ...]: + return tuple(cls._known_attrs) + + @classmethod + def discard_attr(cls, attr: str) -> None: + cls._known_attrs.discard(attr) + + @classmethod + def clear_attrs(cls) -> None: + cls._known_attrs = set() + + +def get_root_param( + param: Parameter | DelegateParameter | None, + parent_param: Parameter | None = None, + alt_source_attrs: Sequence[str] | None = None, +) -> Parameter: + """Return the root parameter in a chain of DelegateParameters or other linking Parameters + + This method recursively searches on the initial parameter. + - If the parameter is a DelegateParameter, it returns the .source. + - If the parameter is not a DelegateParameter, but has an attribute in + either alt_source_attrs or the InferAttrs class which is a parameter, + then it returns that parameter + - If the parameter is None, because the previous DelegateParameter did not have a source + it raises an InferError + + + """ + parent_param = param if parent_param is None else parent_param + if alt_source_attrs is None: + alt_source_attrs_set: Iterable[str] = InferAttrs.known_attrs() + else: + alt_source_attrs_set = set.union( + set(alt_source_attrs), set(InferAttrs.known_attrs()) + ) + + if param is None: + raise InferError(f"Parameter {parent_param} is not attached to a source") + if isinstance(param, DelegateParameter): + return get_root_param(param.source, parent_param) + for alt_source_attr in alt_source_attrs_set: + alt_source = getattr(param, alt_source_attr, None) + if alt_source is not None and isinstance(alt_source, Parameter): + return get_root_param( + alt_source, parent_param=parent_param, alt_source_attrs=alt_source_attrs + ) + return param + + +def infer_instrument( + param: Parameter, + alt_source_attrs: Sequence[str] | None = None, +) -> InstrumentBase: + """Find the instrument that owns a parameter or delegate parameter.""" + root_param = get_root_param(param, alt_source_attrs=alt_source_attrs) + instrument = get_instrument_from_param(root_param) + if isinstance(instrument, InstrumentChannel): + return instrument.root_instrument + elif isinstance(instrument, Instrument): + return instrument + + raise InferError(f"Could not determine source instrument for parameter {param}") + + +def infer_channel( + param: Parameter, + alt_source_attrs: Sequence[str] | None = None, +) -> InstrumentChannel: + """Find the instrument module that owns a parameter.""" + root_param = get_root_param(param, alt_source_attrs=alt_source_attrs) + channel = get_instrument_from_param(root_param) + if isinstance(channel, InstrumentChannel): + return channel + raise InferError( + f"Could not determine a root instrument channel for parameter {param}" + ) + + +def get_instrument_from_param( + param: Parameter, +) -> InstrumentBase: + if param.instrument is not None: + return param.instrument + raise InferError(f"Parameter {param} has no instrument") From 9886408829a50427dc7f166d78244255c9db8cd8 Mon Sep 17 00:00:00 2001 From: Samantha Ho Date: Mon, 22 Apr 2024 15:21:47 -0700 Subject: [PATCH 02/13] Infer improvements --- src/qcodes/parameters/infer.py | 91 +++++++++++++++++++++++----------- 1 file changed, 61 insertions(+), 30 deletions(-) diff --git a/src/qcodes/parameters/infer.py b/src/qcodes/parameters/infer.py index 3fbcd6d5b02..4043628372a 100644 --- a/src/qcodes/parameters/infer.py +++ b/src/qcodes/parameters/infer.py @@ -1,12 +1,15 @@ from __future__ import annotations +from collections.abc import Sequence from typing import TYPE_CHECKING, ClassVar from qcodes.instrument import Instrument, InstrumentBase, InstrumentChannel from qcodes.instrument.parameter import DelegateParameter, Parameter if TYPE_CHECKING: - from collections.abc import Iterable, Sequence + from collections.abc import Iterable + +DOES_NOT_EXIST = "Does not exist" class InferError(AttributeError): ... @@ -35,40 +38,25 @@ def clear_attrs(cls) -> None: def get_root_param( - param: Parameter | DelegateParameter | None, - parent_param: Parameter | None = None, + param: Parameter, alt_source_attrs: Sequence[str] | None = None, ) -> Parameter: - """Return the root parameter in a chain of DelegateParameters or other linking Parameters - - This method recursively searches on the initial parameter. - - If the parameter is a DelegateParameter, it returns the .source. - - If the parameter is not a DelegateParameter, but has an attribute in - either alt_source_attrs or the InferAttrs class which is a parameter, - then it returns that parameter - - If the parameter is None, because the previous DelegateParameter did not have a source - it raises an InferError - - - """ - parent_param = param if parent_param is None else parent_param - if alt_source_attrs is None: - alt_source_attrs_set: Iterable[str] = InferAttrs.known_attrs() - else: - alt_source_attrs_set = set.union( - set(alt_source_attrs), set(InferAttrs.known_attrs()) - ) + """Return the root parameter in a chain of DelegateParameters or other linking Parameters""" + alt_source_attrs_set = _merge_user_and_class_attrs(alt_source_attrs) - if param is None: - raise InferError(f"Parameter {parent_param} is not attached to a source") if isinstance(param, DelegateParameter): - return get_root_param(param.source, parent_param) + if param.source is None: + raise InferError(f"Parameter {param} is not attached to a source") + return get_root_param(param.source) + for alt_source_attr in alt_source_attrs_set: - alt_source = getattr(param, alt_source_attr, None) - if alt_source is not None and isinstance(alt_source, Parameter): - return get_root_param( - alt_source, parent_param=parent_param, alt_source_attrs=alt_source_attrs + alt_source = getattr(param, alt_source_attr, DOES_NOT_EXIST) + if alt_source is None: + raise InferError( + f"Parameter {param} is not attached to a source on attribute {alt_source_attr}" ) + elif isinstance(alt_source, Parameter): + return get_root_param(alt_source, alt_source_attrs=alt_source_attrs) return param @@ -91,7 +79,7 @@ def infer_channel( param: Parameter, alt_source_attrs: Sequence[str] | None = None, ) -> InstrumentChannel: - """Find the instrument module that owns a parameter.""" + """Find the instrument module that owns a parameter or delegate parameter""" root_param = get_root_param(param, alt_source_attrs=alt_source_attrs) channel = get_instrument_from_param(root_param) if isinstance(channel, InstrumentChannel): @@ -107,3 +95,46 @@ def get_instrument_from_param( if param.instrument is not None: return param.instrument raise InferError(f"Parameter {param} has no instrument") + + +def get_parameter_chain( + param_chain: Parameter | Sequence[Parameter], + alt_source_attrs: Sequence[str] | None = None, +) -> tuple[Parameter, ...]: + """Return the chain of DelegateParameters or other linking Parameters""" + alt_source_attrs_set = _merge_user_and_class_attrs(alt_source_attrs) + + if not isinstance(param_chain, Sequence): + param_chain = (param_chain,) + + param = param_chain[-1] + mutable_param_chain = list(param_chain) + if isinstance(param, DelegateParameter): + if param.source is None: + return tuple(param_chain) + mutable_param_chain.append(param.source) + return get_parameter_chain( + mutable_param_chain, + alt_source_attrs=alt_source_attrs, + ) + + for alt_source_attr in alt_source_attrs_set: + alt_source = getattr(param, alt_source_attr, DOES_NOT_EXIST) + if alt_source is None: + return tuple(param_chain) + elif isinstance(alt_source, Parameter): + mutable_param_chain.append(alt_source) + return get_parameter_chain( + mutable_param_chain, + alt_source_attrs=alt_source_attrs, + ) + return tuple(param_chain) + + +def _merge_user_and_class_attrs( + alt_source_attrs: Sequence[str] | None = None, +) -> Iterable[str]: + if alt_source_attrs is None: + return InferAttrs.known_attrs() + else: + return set.union(set(alt_source_attrs), set(InferAttrs.known_attrs())) From 59d6a91f9fc9181ffa406789f053bc8305d4e1b6 Mon Sep 17 00:00:00 2001 From: Samantha Ho Date: Mon, 22 Apr 2024 16:12:33 -0700 Subject: [PATCH 03/13] Clean up naming and handle str | Sequence[str] for attrs --- src/qcodes/parameters/infer.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/qcodes/parameters/infer.py b/src/qcodes/parameters/infer.py index 4043628372a..9259bfde7b7 100644 --- a/src/qcodes/parameters/infer.py +++ b/src/qcodes/parameters/infer.py @@ -3,7 +3,7 @@ from collections.abc import Sequence from typing import TYPE_CHECKING, ClassVar -from qcodes.instrument import Instrument, InstrumentBase, InstrumentChannel +from qcodes.instrument import Instrument, InstrumentBase, InstrumentModule from qcodes.instrument.parameter import DelegateParameter, Parameter if TYPE_CHECKING: @@ -37,7 +37,7 @@ def clear_attrs(cls) -> None: cls._known_attrs = set() -def get_root_param( +def get_root_parameter( param: Parameter, alt_source_attrs: Sequence[str] | None = None, ) -> Parameter: @@ -47,7 +47,7 @@ def get_root_param( if isinstance(param, DelegateParameter): if param.source is None: raise InferError(f"Parameter {param} is not attached to a source") - return get_root_param(param.source) + return get_root_parameter(param.source) for alt_source_attr in alt_source_attrs_set: alt_source = getattr(param, alt_source_attr, DOES_NOT_EXIST) @@ -56,7 +56,7 @@ def get_root_param( f"Parameter {param} is not attached to a source on attribute {alt_source_attr}" ) elif isinstance(alt_source, Parameter): - return get_root_param(alt_source, alt_source_attrs=alt_source_attrs) + return get_root_parameter(alt_source, alt_source_attrs=alt_source_attrs) return param @@ -65,9 +65,9 @@ def infer_instrument( alt_source_attrs: Sequence[str] | None = None, ) -> InstrumentBase: """Find the instrument that owns a parameter or delegate parameter.""" - root_param = get_root_param(param, alt_source_attrs=alt_source_attrs) + root_param = get_root_parameter(param, alt_source_attrs=alt_source_attrs) instrument = get_instrument_from_param(root_param) - if isinstance(instrument, InstrumentChannel): + if isinstance(instrument, InstrumentModule): return instrument.root_instrument elif isinstance(instrument, Instrument): return instrument @@ -78,11 +78,11 @@ def infer_instrument( def infer_channel( param: Parameter, alt_source_attrs: Sequence[str] | None = None, -) -> InstrumentChannel: +) -> InstrumentModule: """Find the instrument module that owns a parameter or delegate parameter""" - root_param = get_root_param(param, alt_source_attrs=alt_source_attrs) + root_param = get_root_parameter(param, alt_source_attrs=alt_source_attrs) channel = get_instrument_from_param(root_param) - if isinstance(channel, InstrumentChannel): + if isinstance(channel, InstrumentModule): return channel raise InferError( f"Could not determine a root instrument channel for parameter {param}" @@ -99,7 +99,7 @@ def get_instrument_from_param( def get_parameter_chain( param_chain: Parameter | Sequence[Parameter], - alt_source_attrs: Sequence[str] | None = None, + alt_source_attrs: str | Sequence[str] | None = None, ) -> tuple[Parameter, ...]: """Return the chain of DelegateParameters or other linking Parameters""" alt_source_attrs_set = _merge_user_and_class_attrs(alt_source_attrs) @@ -132,9 +132,11 @@ def get_parameter_chain( def _merge_user_and_class_attrs( - alt_source_attrs: Sequence[str] | None = None, + alt_source_attrs: str | Sequence[str] | None = None, ) -> Iterable[str]: if alt_source_attrs is None: return InferAttrs.known_attrs() + elif isinstance(alt_source_attrs, str): + return set.union(set((alt_source_attrs,)), set(InferAttrs.known_attrs())) else: return set.union(set(alt_source_attrs), set(InferAttrs.known_attrs())) From 5a9db05abb72db6e6ebb6b4594af34425a62734c Mon Sep 17 00:00:00 2001 From: Samantha Ho Date: Tue, 23 Apr 2024 14:40:38 -0700 Subject: [PATCH 04/13] Add tests for infer --- tests/parameter/test_infer.py | 201 ++++++++++++++++++++++++++++++++++ 1 file changed, 201 insertions(+) create mode 100644 tests/parameter/test_infer.py diff --git a/tests/parameter/test_infer.py b/tests/parameter/test_infer.py new file mode 100644 index 00000000000..491bde7d1d2 --- /dev/null +++ b/tests/parameter/test_infer.py @@ -0,0 +1,201 @@ +from typing import Any + +import numpy as np +import pytest + +from qcodes.instrument import Instrument, InstrumentModule +from qcodes.parameters import DelegateParameter, ManualParameter, Parameter +from qcodes.parameters.infer import ( + InferAttrs, + InferError, + get_parameter_chain, + get_root_parameter, + infer_channel, + infer_instrument, +) + + +class DummyModule(InstrumentModule): + def __init__(self, name: str, parent: Instrument): + super().__init__(name=name, parent=parent) + self.good_chan_parameter = ManualParameter( + "good_chan_parameter", instrument=self + ) + self.bad_chan_parameter = ManualParameter("bad_chan_parameter") + + +class DummyInstrument(Instrument): + def __init__(self, name: str): + super().__init__(name=name) + self.good_inst_parameter = ManualParameter( + "good_inst_parameter", instrument=self + ) + self.bad_inst_parameter = ManualParameter("bad_inst_parameter") + self.module = DummyModule(name="module", parent=self) + + +class DummyDelegateInstrument(Instrument): + def __init__(self, name: str): + super().__init__(name=name) + self.inst_delegate = DelegateParameter( + name="inst_delegate", source=None, instrument=self, bind_to_instrument=True + ) + self.module = DummyDelegateModule(name="dummy_delegate_module", parent=self) + + +class DummyDelegateModule(InstrumentModule): + def __init__(self, name: str, parent: Instrument): + super().__init__(name=name, parent=parent) + self.chan_delegate = DelegateParameter( + name="chan_delegate", source=None, instrument=self, bind_to_instrument=True + ) + + +class UserLinkingParameter(Parameter): + def __init__( + self, name: str, linked_parameter: Parameter | None = None, **kwargs: Any + ): + super().__init__(name=name, **kwargs) + self.linked_parameter: Parameter | None = linked_parameter + + +@pytest.fixture(name="instrument_fixture") +def make_instrument_fixture(): + inst = DummyInstrument("dummy_instrument") + InferAttrs.clear_attrs() + try: + yield inst + finally: + inst.close() + + +@pytest.fixture(name="good_inst_delegates") +def make_good_delegate_parameters(instrument_fixture): + inst = instrument_fixture + good_inst_del_1 = DelegateParameter( + "good_inst_del_1", source=inst.good_inst_parameter + ) + good_inst_del_2 = DelegateParameter("good_inst_del_2", source=good_inst_del_1) + good_inst_del_3 = UserLinkingParameter( + "good_inst_del_3", linked_parameter=good_inst_del_2 + ) + return good_inst_del_1, good_inst_del_2, good_inst_del_3 + + +def test_get_root_parameter_valid(instrument_fixture, good_inst_delegates): + inst = instrument_fixture + good_inst_del_1, good_inst_del_2, good_inst_del_3 = good_inst_delegates + + assert get_root_parameter(good_inst_del_1) is inst.good_inst_parameter + assert get_root_parameter(good_inst_del_2) is inst.good_inst_parameter + + assert ( + get_root_parameter(good_inst_del_3, "linked_parameter") + is inst.good_inst_parameter + ) + + InferAttrs.clear_attrs() + assert get_root_parameter(good_inst_del_3) is good_inst_del_3 + + InferAttrs.add_attr("linked_parameter") + assert get_root_parameter(good_inst_del_3) is inst.good_inst_parameter + + +def test_get_root_parameter_no_source(good_inst_delegates): + good_inst_del_1, good_inst_del_2, _ = good_inst_delegates + + good_inst_del_1.source = None + + with pytest.raises(InferError): + get_root_parameter(good_inst_del_2) + + +def test_get_root_parameter_no_user_attr(good_inst_delegates): + _, _, good_inst_del_3 = good_inst_delegates + InferAttrs.clear_attrs() + assert get_root_parameter(good_inst_del_3, "external_parameter") is good_inst_del_3 + + +def test_get_root_parameter_none_user_attr(good_inst_delegates): + _, _, good_inst_del_3 = good_inst_delegates + good_inst_del_3.linked_parameter = None + with pytest.raises(InferError): + get_root_parameter(good_inst_del_3, "linked_parameter") + + +def test_infer_instrument_valid(instrument_fixture, good_inst_delegates): + inst = instrument_fixture + _, _, good_inst_del_3 = good_inst_delegates + InferAttrs.add_attr("linked_parameter") + assert infer_instrument(good_inst_del_3) is inst + + +def test_infer_instrument_no_instrument(instrument_fixture): + inst = instrument_fixture + no_inst_delegate = DelegateParameter( + "no_inst_delegate", source=inst.bad_inst_parameter + ) + with pytest.raises(InferError): + infer_instrument(no_inst_delegate) + + +def test_infer_channel_valid(instrument_fixture): + inst = instrument_fixture + chan_delegate = DelegateParameter( + "chan_delegate", source=inst.module.good_chan_parameter + ) + assert infer_channel(chan_delegate) is inst.module + + +def test_infer_channel_no_channel(instrument_fixture): + inst = instrument_fixture + no_chan_delegate = DelegateParameter( + "no_chan_delegate", source=inst.module.bad_chan_parameter + ) + with pytest.raises(InferError): + infer_instrument(no_chan_delegate) + + +def test_get_parameter_chain(instrument_fixture, good_inst_delegates): + inst = instrument_fixture + good_inst_del_1, good_inst_del_2, good_inst_del_3 = good_inst_delegates + parameter_chain = get_parameter_chain(good_inst_del_3, "linked_parameter") + assert np.all( + [ + param in parameter_chain + for param in ( + inst.good_inst_parameter, + good_inst_del_1, + good_inst_del_2, + good_inst_del_3, + ) + ] + ) + + # This is a broken chain. get_root_parameter would throw an InferError, but + # get_parameter_chain should run successfully + good_inst_del_1.source = None + parameter_chain = get_parameter_chain(good_inst_del_3, "linked_parameter") + assert np.all( + [ + param in parameter_chain + for param in ( + good_inst_del_1, + good_inst_del_2, + good_inst_del_3, + ) + ] + ) + + +def test_parameters_on_delegate_instruments(instrument_fixture, good_inst_delegates): + inst = instrument_fixture + _, good_inst_del_2, _ = good_inst_delegates + + delegate_inst = DummyDelegateInstrument("dummy_delegate_instrument") + delegate_inst.inst_delegate.source = good_inst_del_2 + delegate_inst.module.chan_delegate.source = inst.module.good_chan_parameter + + assert infer_channel(delegate_inst.module.chan_delegate) is inst.module + assert infer_instrument(delegate_inst.module.chan_delegate) is inst + assert infer_instrument(delegate_inst.inst_delegate) is inst From f545702f3bfe7510317ea19705e8745b7c569e7b Mon Sep 17 00:00:00 2001 From: Samantha Ho Date: Tue, 23 Apr 2024 16:07:38 -0700 Subject: [PATCH 05/13] Add newsfragment Fix mypy errors for 3.9 --- docs/changes/newsfragments/5998.new | 4 ++++ tests/parameter/test_infer.py | 2 ++ 2 files changed, 6 insertions(+) create mode 100644 docs/changes/newsfragments/5998.new diff --git a/docs/changes/newsfragments/5998.new b/docs/changes/newsfragments/5998.new new file mode 100644 index 00000000000..9ea1de4d653 --- /dev/null +++ b/docs/changes/newsfragments/5998.new @@ -0,0 +1,4 @@ +Add methods to recursively search a chain of DelegateParameters and return either all the parameters in the chain or the 'root' parameter +These methods may also be used with custom Parameters which link to other parameters via different attribute names + +Also add infer_channel and infer_instrument methods to find the InstrumentModule or Instrument of the root parameter \ No newline at end of file diff --git a/tests/parameter/test_infer.py b/tests/parameter/test_infer.py index 491bde7d1d2..6afc32f073a 100644 --- a/tests/parameter/test_infer.py +++ b/tests/parameter/test_infer.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Any import numpy as np From b25bb842382105ace6d396eaea09856d0caa07a9 Mon Sep 17 00:00:00 2001 From: Samantha Ho Date: Tue, 23 Apr 2024 16:10:54 -0700 Subject: [PATCH 06/13] pre-commit fixes --- docs/changes/newsfragments/5998.new | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/changes/newsfragments/5998.new b/docs/changes/newsfragments/5998.new index 9ea1de4d653..b68ee0a99c0 100644 --- a/docs/changes/newsfragments/5998.new +++ b/docs/changes/newsfragments/5998.new @@ -1,4 +1,4 @@ Add methods to recursively search a chain of DelegateParameters and return either all the parameters in the chain or the 'root' parameter These methods may also be used with custom Parameters which link to other parameters via different attribute names -Also add infer_channel and infer_instrument methods to find the InstrumentModule or Instrument of the root parameter \ No newline at end of file +Also add infer_channel and infer_instrument methods to find the InstrumentModule or Instrument of the root parameter From 9deb55c89094ff972e2644b01d7d2a9f4f3efcb6 Mon Sep 17 00:00:00 2001 From: Samantha Ho Date: Wed, 24 Apr 2024 08:11:46 -0700 Subject: [PATCH 07/13] Augmenting tests for code coverage --- tests/parameter/test_infer.py | 81 ++++++++++++++++++++++------------- 1 file changed, 52 insertions(+), 29 deletions(-) diff --git a/tests/parameter/test_infer.py b/tests/parameter/test_infer.py index 6afc32f073a..227c0d44055 100644 --- a/tests/parameter/test_infer.py +++ b/tests/parameter/test_infer.py @@ -2,14 +2,14 @@ from typing import Any -import numpy as np import pytest -from qcodes.instrument import Instrument, InstrumentModule +from qcodes.instrument import Instrument, InstrumentBase, InstrumentModule from qcodes.parameters import DelegateParameter, ManualParameter, Parameter from qcodes.parameters.infer import ( InferAttrs, InferError, + _merge_user_and_class_attrs, get_parameter_chain, get_root_parameter, infer_channel, @@ -36,17 +36,18 @@ def __init__(self, name: str): self.module = DummyModule(name="module", parent=self) -class DummyDelegateInstrument(Instrument): +class DummyDelegateInstrument(InstrumentBase): def __init__(self, name: str): super().__init__(name=name) self.inst_delegate = DelegateParameter( name="inst_delegate", source=None, instrument=self, bind_to_instrument=True ) self.module = DummyDelegateModule(name="dummy_delegate_module", parent=self) + self.inst_base_parameter = ManualParameter("inst_base_parameter") class DummyDelegateModule(InstrumentModule): - def __init__(self, name: str, parent: Instrument): + def __init__(self, name: str, parent: InstrumentBase): super().__init__(name=name, parent=parent) self.chan_delegate = DelegateParameter( name="chan_delegate", source=None, instrument=self, bind_to_instrument=True @@ -108,8 +109,9 @@ def test_get_root_parameter_no_source(good_inst_delegates): good_inst_del_1.source = None - with pytest.raises(InferError): + with pytest.raises(InferError) as exc_info: get_root_parameter(good_inst_del_2) + assert "is not attached to a source" in str(exc_info.value) def test_get_root_parameter_no_user_attr(good_inst_delegates): @@ -121,8 +123,9 @@ def test_get_root_parameter_no_user_attr(good_inst_delegates): def test_get_root_parameter_none_user_attr(good_inst_delegates): _, _, good_inst_del_3 = good_inst_delegates good_inst_del_3.linked_parameter = None - with pytest.raises(InferError): + with pytest.raises(InferError) as exc_info: get_root_parameter(good_inst_del_3, "linked_parameter") + assert "is not attached to a source on attribute" in str(exc_info.value) def test_infer_instrument_valid(instrument_fixture, good_inst_delegates): @@ -137,8 +140,16 @@ def test_infer_instrument_no_instrument(instrument_fixture): no_inst_delegate = DelegateParameter( "no_inst_delegate", source=inst.bad_inst_parameter ) - with pytest.raises(InferError): + with pytest.raises(InferError) as exc_info: infer_instrument(no_inst_delegate) + assert "has no instrument" in str(exc_info.value) + + +def test_infer_instrument_root_instrument_base(): + delegate_inst = DummyDelegateInstrument("dummy_delegate_instrument") + + with pytest.raises(InferError): + infer_instrument(delegate_inst.inst_base_parameter) def test_infer_channel_valid(instrument_fixture): @@ -154,39 +165,45 @@ def test_infer_channel_no_channel(instrument_fixture): no_chan_delegate = DelegateParameter( "no_chan_delegate", source=inst.module.bad_chan_parameter ) - with pytest.raises(InferError): - infer_instrument(no_chan_delegate) + with pytest.raises(InferError) as exc_info: + infer_channel(no_chan_delegate) + assert "has no instrument" in str(exc_info.value) + + inst_but_not_chan_delegate = DelegateParameter( + "inst_but_not_chan_delegate", source=inst.good_inst_parameter + ) + with pytest.raises(InferError) as exc_info: + infer_channel(inst_but_not_chan_delegate) + assert "Could not determine a root instrument channel" in str(exc_info.value) def test_get_parameter_chain(instrument_fixture, good_inst_delegates): inst = instrument_fixture good_inst_del_1, good_inst_del_2, good_inst_del_3 = good_inst_delegates parameter_chain = get_parameter_chain(good_inst_del_3, "linked_parameter") - assert np.all( - [ - param in parameter_chain - for param in ( - inst.good_inst_parameter, - good_inst_del_1, - good_inst_del_2, - good_inst_del_3, - ) - ] - ) + assert set( + ( + inst.good_inst_parameter, + good_inst_del_1, + good_inst_del_2, + good_inst_del_3, + ) + ) == set(parameter_chain) # This is a broken chain. get_root_parameter would throw an InferError, but # get_parameter_chain should run successfully good_inst_del_1.source = None parameter_chain = get_parameter_chain(good_inst_del_3, "linked_parameter") - assert np.all( - [ - param in parameter_chain - for param in ( - good_inst_del_1, - good_inst_del_2, - good_inst_del_3, - ) - ] + assert set((good_inst_del_1, good_inst_del_2, good_inst_del_3)) == set( + parameter_chain + ) + + # Make the linked_parameter at the end of the chain + good_inst_del_3.linked_parameter = None + good_inst_del_1.source = good_inst_del_3 + parameter_chain = get_parameter_chain(good_inst_del_2, "linked_parameter") + assert set((good_inst_del_1, good_inst_del_2, good_inst_del_3)) == set( + parameter_chain ) @@ -201,3 +218,9 @@ def test_parameters_on_delegate_instruments(instrument_fixture, good_inst_delega assert infer_channel(delegate_inst.module.chan_delegate) is inst.module assert infer_instrument(delegate_inst.module.chan_delegate) is inst assert infer_instrument(delegate_inst.inst_delegate) is inst + + +def test_merge_user_and_class_attrs(): + InferAttrs.add_attr("attr1") + attr_set = _merge_user_and_class_attrs("attr2") + assert set(("attr1", "attr2")) == attr_set From fbd2a8a319210b59c42630c213bd7035c4413e9f Mon Sep 17 00:00:00 2001 From: Samantha Ho Date: Wed, 24 Apr 2024 08:38:48 -0700 Subject: [PATCH 08/13] Make InferAttrs.add_attrs take lists of attrs to add as well Add more code coverage in tests --- src/qcodes/parameters/infer.py | 6 ++++-- tests/parameter/test_infer.py | 27 ++++++++++++++++++++++++--- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/src/qcodes/parameters/infer.py b/src/qcodes/parameters/infer.py index 9259bfde7b7..70fc9fad35a 100644 --- a/src/qcodes/parameters/infer.py +++ b/src/qcodes/parameters/infer.py @@ -21,8 +21,10 @@ class InferAttrs: _known_attrs: ClassVar[set[str]] = set() @classmethod - def add_attr(cls, attr: str) -> None: - cls._known_attrs.add(attr) + def add_attrs(cls, attrs: str | Iterable[str]) -> None: + if isinstance(attrs, str): + attrs = (attrs,) + cls._known_attrs.update(set(attrs)) @classmethod def known_attrs(cls) -> tuple[str, ...]: diff --git a/tests/parameter/test_infer.py b/tests/parameter/test_infer.py index 227c0d44055..f779071f668 100644 --- a/tests/parameter/test_infer.py +++ b/tests/parameter/test_infer.py @@ -43,7 +43,9 @@ def __init__(self, name: str): name="inst_delegate", source=None, instrument=self, bind_to_instrument=True ) self.module = DummyDelegateModule(name="dummy_delegate_module", parent=self) - self.inst_base_parameter = ManualParameter("inst_base_parameter") + self.inst_base_parameter = ManualParameter( + "inst_base_parameter", instrument=self + ) class DummyDelegateModule(InstrumentModule): @@ -148,8 +150,9 @@ def test_infer_instrument_no_instrument(instrument_fixture): def test_infer_instrument_root_instrument_base(): delegate_inst = DummyDelegateInstrument("dummy_delegate_instrument") - with pytest.raises(InferError): + with pytest.raises(InferError) as exc_info: infer_instrument(delegate_inst.inst_base_parameter) + assert "Could not determine source instrument for parameter" in str(exc_info.value) def test_infer_channel_valid(instrument_fixture): @@ -221,6 +224,24 @@ def test_parameters_on_delegate_instruments(instrument_fixture, good_inst_delega def test_merge_user_and_class_attrs(): - InferAttrs.add_attr("attr1") + InferAttrs.add_attrs("attr1") attr_set = _merge_user_and_class_attrs("attr2") assert set(("attr1", "attr2")) == attr_set + + attr_set_list = _merge_user_and_class_attrs(("attr2", "attr3")) + assert set(("attr1", "attr2", "attr3")) == attr_set_list + + +def test_infer_attrs(): + InferAttrs.clear_attrs() + assert InferAttrs.known_attrs() == () + + InferAttrs.add_attrs("attr1") + assert set(InferAttrs.known_attrs()) == set(("attr1",)) + + InferAttrs.add_attrs("attr2") + InferAttrs.discard_attr("attr1") + assert set(InferAttrs.known_attrs()) == set(("attr2",)) + + InferAttrs.add_attrs(("attr1", "attr3")) + assert set(InferAttrs.known_attrs()) == set(("attr1", "attr2", "attr3")) From eaa29c5f6455e41aa13be65df79ed9e0e79a5be3 Mon Sep 17 00:00:00 2001 From: Samantha Ho Date: Wed, 24 Apr 2024 08:50:04 -0700 Subject: [PATCH 09/13] Fix missing updates --- tests/parameter/test_infer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/parameter/test_infer.py b/tests/parameter/test_infer.py index f779071f668..79b489f0d91 100644 --- a/tests/parameter/test_infer.py +++ b/tests/parameter/test_infer.py @@ -102,7 +102,7 @@ def test_get_root_parameter_valid(instrument_fixture, good_inst_delegates): InferAttrs.clear_attrs() assert get_root_parameter(good_inst_del_3) is good_inst_del_3 - InferAttrs.add_attr("linked_parameter") + InferAttrs.add_attrs("linked_parameter") assert get_root_parameter(good_inst_del_3) is inst.good_inst_parameter @@ -133,7 +133,7 @@ def test_get_root_parameter_none_user_attr(good_inst_delegates): def test_infer_instrument_valid(instrument_fixture, good_inst_delegates): inst = instrument_fixture _, _, good_inst_del_3 = good_inst_delegates - InferAttrs.add_attr("linked_parameter") + InferAttrs.add_attrs("linked_parameter") assert infer_instrument(good_inst_del_3) is inst From 33141d2cbad45730fda45bca955a1da279faaf05 Mon Sep 17 00:00:00 2001 From: Samantha Ho Date: Thu, 25 Apr 2024 10:23:24 -0700 Subject: [PATCH 10/13] Applying changes from code review - Move to extensions module - Less verbose members of InferAttrs - rename to infer_instrument_module - Fix legacy import locations --- src/qcodes/extensions/__init__.py | 15 +++++++++++++++ src/qcodes/{parameters => extensions}/infer.py | 18 +++++++++++++----- tests/parameter/test_infer.py | 6 +++--- 3 files changed, 31 insertions(+), 8 deletions(-) rename src/qcodes/{parameters => extensions}/infer.py (91%) diff --git a/src/qcodes/extensions/__init__.py b/src/qcodes/extensions/__init__.py index 9fee703eeac..a18aba947c0 100644 --- a/src/qcodes/extensions/__init__.py +++ b/src/qcodes/extensions/__init__.py @@ -2,12 +2,27 @@ The extensions module contains smaller modules that extend the functionality of QCoDeS. These modules may import from all of QCoDeS but do not themselves get imported into QCoDeS. """ + from ._driver_test_case import DriverTestCase from ._log_export_info import log_dataset_export_info +from .infer import ( + InferAttrs, + InferError, + get_root_parameter, + infer_channel, + infer_instrument, + infer_instrument_module, +) from .installation import register_station_schema_with_vscode __all__ = [ "register_station_schema_with_vscode", "log_dataset_export_info", "DriverTestCase", + "InferAttrs", + "InferError", + "get_root_parameter", + "infer_channel", + "infer_instrument", + "infer_instrument_module", ] diff --git a/src/qcodes/parameters/infer.py b/src/qcodes/extensions/infer.py similarity index 91% rename from src/qcodes/parameters/infer.py rename to src/qcodes/extensions/infer.py index 70fc9fad35a..001ba2e3fbe 100644 --- a/src/qcodes/parameters/infer.py +++ b/src/qcodes/extensions/infer.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, ClassVar from qcodes.instrument import Instrument, InstrumentBase, InstrumentModule -from qcodes.instrument.parameter import DelegateParameter, Parameter +from qcodes.parameters import DelegateParameter, Parameter if TYPE_CHECKING: from collections.abc import Iterable @@ -21,7 +21,7 @@ class InferAttrs: _known_attrs: ClassVar[set[str]] = set() @classmethod - def add_attrs(cls, attrs: str | Iterable[str]) -> None: + def add(cls, attrs: str | Iterable[str]) -> None: if isinstance(attrs, str): attrs = (attrs,) cls._known_attrs.update(set(attrs)) @@ -31,11 +31,11 @@ def known_attrs(cls) -> tuple[str, ...]: return tuple(cls._known_attrs) @classmethod - def discard_attr(cls, attr: str) -> None: + def discard(cls, attr: str) -> None: cls._known_attrs.discard(attr) @classmethod - def clear_attrs(cls) -> None: + def clear(cls) -> None: cls._known_attrs = set() @@ -77,7 +77,7 @@ def infer_instrument( raise InferError(f"Could not determine source instrument for parameter {param}") -def infer_channel( +def infer_instrument_module( param: Parameter, alt_source_attrs: Sequence[str] | None = None, ) -> InstrumentModule: @@ -91,6 +91,14 @@ def infer_channel( ) +def infer_channel( + param: Parameter, + alt_source_attrs: Sequence[str] | None = None, +) -> InstrumentModule: + """An alias for infer_instrument_module""" + return infer_instrument_module(param, alt_source_attrs) + + def get_instrument_from_param( param: Parameter, ) -> InstrumentBase: diff --git a/tests/parameter/test_infer.py b/tests/parameter/test_infer.py index 79b489f0d91..86a762d9db5 100644 --- a/tests/parameter/test_infer.py +++ b/tests/parameter/test_infer.py @@ -4,9 +4,7 @@ import pytest -from qcodes.instrument import Instrument, InstrumentBase, InstrumentModule -from qcodes.parameters import DelegateParameter, ManualParameter, Parameter -from qcodes.parameters.infer import ( +from qcodes.extensions.infer import ( InferAttrs, InferError, _merge_user_and_class_attrs, @@ -15,6 +13,8 @@ infer_channel, infer_instrument, ) +from qcodes.instrument import Instrument, InstrumentBase, InstrumentModule +from qcodes.parameters import DelegateParameter, ManualParameter, Parameter class DummyModule(InstrumentModule): From b78cde3b3a5e3e84c3a237ceb34a9e4de68c3f24 Mon Sep 17 00:00:00 2001 From: Samantha Ho Date: Thu, 25 Apr 2024 10:58:07 -0700 Subject: [PATCH 11/13] Handle looped linking parameters Expand docstrings --- src/qcodes/extensions/infer.py | 91 +++++++++++++++++++++++++++++----- 1 file changed, 78 insertions(+), 13 deletions(-) diff --git a/src/qcodes/extensions/infer.py b/src/qcodes/extensions/infer.py index 001ba2e3fbe..781dda0e7e8 100644 --- a/src/qcodes/extensions/infer.py +++ b/src/qcodes/extensions/infer.py @@ -43,30 +43,53 @@ def get_root_parameter( param: Parameter, alt_source_attrs: Sequence[str] | None = None, ) -> Parameter: - """Return the root parameter in a chain of DelegateParameters or other linking Parameters""" - alt_source_attrs_set = _merge_user_and_class_attrs(alt_source_attrs) + """ + Return the root parameter in a chain of DelegateParameters or other linking Parameters - if isinstance(param, DelegateParameter): - if param.source is None: - raise InferError(f"Parameter {param} is not attached to a source") - return get_root_parameter(param.source) + This method calls get_parameter_chain and then checks for various error conditions + Args: + param: The DelegateParameter or other linking parameter to find the root parameter from + alt_source_attrs: The attribute names for custom linking parameters + + Raises: + InferError: If the linking parameters do not end with a non-linking parameter + InferError: If the chain of linking parameters loops on itself + """ + parameter_chain = get_parameter_chain(param, alt_source_attrs) + root_param = parameter_chain[-1] + + if root_param is parameter_chain[0]: + raise InferError(f"{param} generated a loop of linking parameters") + if isinstance(root_param, DelegateParameter): + raise InferError(f"Parameter {param} is not attached to a source") + + alt_source_attrs_set = _merge_user_and_class_attrs(alt_source_attrs) for alt_source_attr in alt_source_attrs_set: alt_source = getattr(param, alt_source_attr, DOES_NOT_EXIST) if alt_source is None: raise InferError( f"Parameter {param} is not attached to a source on attribute {alt_source_attr}" ) - elif isinstance(alt_source, Parameter): - return get_root_parameter(alt_source, alt_source_attrs=alt_source_attrs) - return param + return root_param def infer_instrument( param: Parameter, alt_source_attrs: Sequence[str] | None = None, ) -> InstrumentBase: - """Find the instrument that owns a parameter or delegate parameter.""" + """ + Find the instrument that owns a parameter or delegate parameter. + + Args: + param: The DelegateParameter or other linking parameter to find the instrument from + alt_source_attrs: The attribute names for custom linking parameters + + Raises: + InferError: If the linking parameters do not end with a non-linking parameter + InferError: If the instrument of the root parameter is None + InferError: If the instrument of the root parameter is not an instance of Instrument + """ root_param = get_root_parameter(param, alt_source_attrs=alt_source_attrs) instrument = get_instrument_from_param(root_param) if isinstance(instrument, InstrumentModule): @@ -81,7 +104,18 @@ def infer_instrument_module( param: Parameter, alt_source_attrs: Sequence[str] | None = None, ) -> InstrumentModule: - """Find the instrument module that owns a parameter or delegate parameter""" + """ + Find the instrument module that owns a parameter or delegate parameter + + Args: + param: The DelegateParameter or other linking parameter to find the instrument module from + alt_source_attrs: The attribute names for custom linking parameters + + Raises: + InferError: If the linking parameters do not end with a non-linking parameter + InferError: If the instrument module of the root parameter is None + InferError: If the instrument module of the root parameter is not an instance of InstrumentModule + """ root_param = get_root_parameter(param, alt_source_attrs=alt_source_attrs) channel = get_instrument_from_param(root_param) if isinstance(channel, InstrumentModule): @@ -102,6 +136,15 @@ def infer_channel( def get_instrument_from_param( param: Parameter, ) -> InstrumentBase: + """ + Return the instrument attribute from a parameter + + Args: + param: The parameter to get the instrument module from + + Raises: + InferError: If the parameter does not have an instrument + """ if param.instrument is not None: return param.instrument raise InferError(f"Parameter {param} has no instrument") @@ -111,7 +154,24 @@ def get_parameter_chain( param_chain: Parameter | Sequence[Parameter], alt_source_attrs: str | Sequence[str] | None = None, ) -> tuple[Parameter, ...]: - """Return the chain of DelegateParameters or other linking Parameters""" + """ + Return the chain of DelegateParameters or other linking Parameters + + This method traverses singly-linked parameters and returns the resulting chain + If the parameters loop, then the first and last linking parameters in the chain + will be identical. Otherwise, the chain starts with the initial argument passed + and ends when the chain terminates in either a non-linking parameter or a + linking parameter that links to None + + The search prioritizes the `source` attribute of DelegateParameters first, and + then looks for other linking attributes in undetermined order. + + Args: + param_chain: The initial linking parameter or a List linking parameters + from which to return the chain + alt_source_attrs: The attribute names for custom linking parameters + """ + alt_source_attrs_set = _merge_user_and_class_attrs(alt_source_attrs) if not isinstance(param_chain, Sequence): @@ -123,6 +183,8 @@ def get_parameter_chain( if param.source is None: return tuple(param_chain) mutable_param_chain.append(param.source) + if param.source in param_chain: # There is a loop in the links + return tuple(mutable_param_chain) return get_parameter_chain( mutable_param_chain, alt_source_attrs=alt_source_attrs, @@ -130,10 +192,12 @@ def get_parameter_chain( for alt_source_attr in alt_source_attrs_set: alt_source = getattr(param, alt_source_attr, DOES_NOT_EXIST) - if alt_source is None: + if alt_source is None: # Valid linking attribute, but no link parameter return tuple(param_chain) elif isinstance(alt_source, Parameter): mutable_param_chain.append(alt_source) + if alt_source in param_chain: # There is a loop in the links + return tuple(mutable_param_chain) return get_parameter_chain( mutable_param_chain, alt_source_attrs=alt_source_attrs, @@ -144,6 +208,7 @@ def get_parameter_chain( def _merge_user_and_class_attrs( alt_source_attrs: str | Sequence[str] | None = None, ) -> Iterable[str]: + """Merges user-supplied linking attributes with attributes from InferAttrs""" if alt_source_attrs is None: return InferAttrs.known_attrs() elif isinstance(alt_source_attrs, str): From 123a2ba2be520e91f45cfb16c56052235da1c7ae Mon Sep 17 00:00:00 2001 From: Samantha Ho Date: Thu, 25 Apr 2024 11:12:14 -0700 Subject: [PATCH 12/13] Extend tests to cover loops Fix bug from adding loop detection Move tests --- src/qcodes/extensions/infer.py | 2 +- tests/{parameter => extensions}/test_infer.py | 81 +++++++++++++------ 2 files changed, 59 insertions(+), 24 deletions(-) rename tests/{parameter => extensions}/test_infer.py (81%) diff --git a/src/qcodes/extensions/infer.py b/src/qcodes/extensions/infer.py index 781dda0e7e8..1937e20da6a 100644 --- a/src/qcodes/extensions/infer.py +++ b/src/qcodes/extensions/infer.py @@ -59,7 +59,7 @@ def get_root_parameter( parameter_chain = get_parameter_chain(param, alt_source_attrs) root_param = parameter_chain[-1] - if root_param is parameter_chain[0]: + if root_param is parameter_chain[0] and len(parameter_chain) > 1: raise InferError(f"{param} generated a loop of linking parameters") if isinstance(root_param, DelegateParameter): raise InferError(f"Parameter {param} is not attached to a source") diff --git a/tests/parameter/test_infer.py b/tests/extensions/test_infer.py similarity index 81% rename from tests/parameter/test_infer.py rename to tests/extensions/test_infer.py index 86a762d9db5..48c29d43a6d 100644 --- a/tests/parameter/test_infer.py +++ b/tests/extensions/test_infer.py @@ -2,6 +2,7 @@ from typing import Any +import numpy as np import pytest from qcodes.extensions.infer import ( @@ -67,7 +68,7 @@ def __init__( @pytest.fixture(name="instrument_fixture") def make_instrument_fixture(): inst = DummyInstrument("dummy_instrument") - InferAttrs.clear_attrs() + InferAttrs.clear() try: yield inst finally: @@ -99,10 +100,10 @@ def test_get_root_parameter_valid(instrument_fixture, good_inst_delegates): is inst.good_inst_parameter ) - InferAttrs.clear_attrs() + InferAttrs.clear() assert get_root_parameter(good_inst_del_3) is good_inst_del_3 - InferAttrs.add_attrs("linked_parameter") + InferAttrs.add("linked_parameter") assert get_root_parameter(good_inst_del_3) is inst.good_inst_parameter @@ -118,7 +119,7 @@ def test_get_root_parameter_no_source(good_inst_delegates): def test_get_root_parameter_no_user_attr(good_inst_delegates): _, _, good_inst_del_3 = good_inst_delegates - InferAttrs.clear_attrs() + InferAttrs.clear() assert get_root_parameter(good_inst_del_3, "external_parameter") is good_inst_del_3 @@ -133,7 +134,7 @@ def test_get_root_parameter_none_user_attr(good_inst_delegates): def test_infer_instrument_valid(instrument_fixture, good_inst_delegates): inst = instrument_fixture _, _, good_inst_del_3 = good_inst_delegates - InferAttrs.add_attrs("linked_parameter") + InferAttrs.add("linked_parameter") assert infer_instrument(good_inst_del_3) is inst @@ -184,29 +185,40 @@ def test_get_parameter_chain(instrument_fixture, good_inst_delegates): inst = instrument_fixture good_inst_del_1, good_inst_del_2, good_inst_del_3 = good_inst_delegates parameter_chain = get_parameter_chain(good_inst_del_3, "linked_parameter") - assert set( - ( - inst.good_inst_parameter, - good_inst_del_1, - good_inst_del_2, - good_inst_del_3, - ) - ) == set(parameter_chain) + expected_chain = ( + good_inst_del_3, + good_inst_del_2, + good_inst_del_1, + inst.good_inst_parameter, + ) + assert np.all( + [parameter_chain[i] is param for i, param in enumerate(expected_chain)] + ) # This is a broken chain. get_root_parameter would throw an InferError, but # get_parameter_chain should run successfully good_inst_del_1.source = None parameter_chain = get_parameter_chain(good_inst_del_3, "linked_parameter") - assert set((good_inst_del_1, good_inst_del_2, good_inst_del_3)) == set( - parameter_chain + expected_chain = ( + good_inst_del_3, + good_inst_del_2, + good_inst_del_1, + ) + assert np.all( + [parameter_chain[i] is param for i, param in enumerate(expected_chain)] ) # Make the linked_parameter at the end of the chain good_inst_del_3.linked_parameter = None good_inst_del_1.source = good_inst_del_3 parameter_chain = get_parameter_chain(good_inst_del_2, "linked_parameter") - assert set((good_inst_del_1, good_inst_del_2, good_inst_del_3)) == set( - parameter_chain + expected_chain = ( + good_inst_del_2, + good_inst_del_1, + good_inst_del_3, + ) + assert np.all( + [parameter_chain[i] is param for i, param in enumerate(expected_chain)] ) @@ -224,7 +236,7 @@ def test_parameters_on_delegate_instruments(instrument_fixture, good_inst_delega def test_merge_user_and_class_attrs(): - InferAttrs.add_attrs("attr1") + InferAttrs.add("attr1") attr_set = _merge_user_and_class_attrs("attr2") assert set(("attr1", "attr2")) == attr_set @@ -233,15 +245,38 @@ def test_merge_user_and_class_attrs(): def test_infer_attrs(): - InferAttrs.clear_attrs() + InferAttrs.clear() assert InferAttrs.known_attrs() == () - InferAttrs.add_attrs("attr1") + InferAttrs.add("attr1") assert set(InferAttrs.known_attrs()) == set(("attr1",)) - InferAttrs.add_attrs("attr2") - InferAttrs.discard_attr("attr1") + InferAttrs.add("attr2") + InferAttrs.discard("attr1") assert set(InferAttrs.known_attrs()) == set(("attr2",)) - InferAttrs.add_attrs(("attr1", "attr3")) + InferAttrs.add(("attr1", "attr3")) assert set(InferAttrs.known_attrs()) == set(("attr1", "attr2", "attr3")) + + +def test_get_parameter_chain_with_loops(good_inst_delegates): + good_inst_del_1, good_inst_del_2, good_inst_del_3 = good_inst_delegates + good_inst_del_1.source = good_inst_del_3 + parameter_chain = get_parameter_chain(good_inst_del_3, "linked_parameter") + expected_chain = ( + good_inst_del_3, + good_inst_del_2, + good_inst_del_1, + good_inst_del_3, + ) + assert np.all( + [parameter_chain[i] is param for i, param in enumerate(expected_chain)] + ) + + +def test_get_root_parameter_with_loops(good_inst_delegates): + good_inst_del_1, _, good_inst_del_3 = good_inst_delegates + good_inst_del_1.source = good_inst_del_3 + with pytest.raises(InferError) as exc_info: + get_root_parameter(good_inst_del_3, "linked_parameter") + assert "generated a loop of linking parameters" in str(exc_info.value) From da1645a3f3e8b6cd9c17eaa40ed1c740f5235131 Mon Sep 17 00:00:00 2001 From: Samantha Ho Date: Thu, 25 Apr 2024 11:29:07 -0700 Subject: [PATCH 13/13] Modify loop test for code coverage --- tests/extensions/test_infer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/extensions/test_infer.py b/tests/extensions/test_infer.py index 48c29d43a6d..d0453b96e12 100644 --- a/tests/extensions/test_infer.py +++ b/tests/extensions/test_infer.py @@ -275,8 +275,8 @@ def test_get_parameter_chain_with_loops(good_inst_delegates): def test_get_root_parameter_with_loops(good_inst_delegates): - good_inst_del_1, _, good_inst_del_3 = good_inst_delegates + good_inst_del_1, good_inst_del_2, good_inst_del_3 = good_inst_delegates good_inst_del_1.source = good_inst_del_3 with pytest.raises(InferError) as exc_info: - get_root_parameter(good_inst_del_3, "linked_parameter") + get_root_parameter(good_inst_del_2, "linked_parameter") assert "generated a loop of linking parameters" in str(exc_info.value)