diff --git a/qiskit_experiments/curve_analysis/curve_analysis.py b/qiskit_experiments/curve_analysis/curve_analysis.py index 815e7ac9c3..ed9b5e26ba 100644 --- a/qiskit_experiments/curve_analysis/curve_analysis.py +++ b/qiskit_experiments/curve_analysis/curve_analysis.py @@ -239,6 +239,7 @@ class AnalysisExample(CurveAnalysis): def __init__(self): """Initialize data fields that are privately accessed by methods.""" + super().__init__() #: Dict[str, Any]: Experiment metadata self.__experiment_metadata = None @@ -821,35 +822,12 @@ def _get_option(self, arg_name: str) -> Any: ) from ex def _run_analysis( - self, experiment_data: ExperimentData, **options + self, experiment_data: ExperimentData ) -> Tuple[List[AnalysisResultData], List["pyplot.Figure"]]: - """Run analysis on circuit data. - - Args: - experiment_data: the experiment data to analyze. - options: kwarg options for analysis function. - - Returns: - tuple: A pair ``(analysis_results, figures)`` where ``analysis_results`` - is a list of :class:`AnalysisResultData` objects, and ``figures`` - is a list of any figures for the experiment. - - Raises: - AnalysisError: If the analysis fails. - DataProcessorError: When data processing failed. - """ - # # 1. Parse arguments # - - # Pop arguments that are not given to the fitter, - # and update class attributes with the arguments that are given to the fitter - # (arguments that have matching attributes in the class) - analysis_options = self._default_options().__dict__ - analysis_options.update(options) - - extra_options = self._arg_parse(**analysis_options) + extra_options = self._arg_parse(**self.options.__dict__) # Update all fit functions in the series definitions if fixed parameter is defined. # Fixed parameters should be provided by the analysis options. diff --git a/qiskit_experiments/framework/__init__.py b/qiskit_experiments/framework/__init__.py index f9fa5e76f5..3ec1b6913d 100644 --- a/qiskit_experiments/framework/__init__.py +++ b/qiskit_experiments/framework/__init__.py @@ -209,6 +209,7 @@ FitVal AnalysisResultData ExperimentConfig + AnalysisConfig ExperimentEncoder ExperimentDecoder @@ -238,7 +239,8 @@ from qiskit_experiments.database_service.db_analysis_result import DbAnalysisResultV1 from qiskit_experiments.database_service.db_fitval import FitVal from .base_analysis import BaseAnalysis -from .base_experiment import BaseExperiment, ExperimentConfig +from .base_experiment import BaseExperiment +from .configs import ExperimentConfig, AnalysisConfig from .analysis_result_data import AnalysisResultData from .experiment_data import ExperimentData from .composite import ( diff --git a/qiskit_experiments/framework/base_analysis.py b/qiskit_experiments/framework/base_analysis.py index eea12cb85d..e6073e34e6 100644 --- a/qiskit_experiments/framework/base_analysis.py +++ b/qiskit_experiments/framework/base_analysis.py @@ -14,16 +14,20 @@ """ from abc import ABC, abstractmethod -from typing import List, Tuple +import copy +from collections import OrderedDict +from typing import List, Tuple, Union, Dict from qiskit_experiments.database_service.device_component import Qubit from qiskit_experiments.framework import Options +from qiskit_experiments.framework.store_init_args import StoreInitArgs from qiskit_experiments.framework.experiment_data import ExperimentData +from qiskit_experiments.framework.configs import AnalysisConfig from qiskit_experiments.framework.analysis_result_data import AnalysisResultData from qiskit_experiments.database_service import DbAnalysisResultV1 -class BaseAnalysis(ABC): +class BaseAnalysis(ABC, StoreInitArgs): """Abstract base class for analyzing Experiment data. The data produced by experiments (i.e. subclasses of BaseExperiment) @@ -32,17 +36,74 @@ class BaseAnalysis(ABC): For example, an analysis may perform some data processing of the measured data and a fit to a function to extract a parameter. - When designing Analysis subclasses default values for any kwarg - analysis options of the `run` method should be set by overriding - the `_default_options` class method. When calling `run` these - default values will be combined with all other option kwargs in the - run method and passed to the `_run_analysis` function. + Analysis subclasses must implement the abstract method `_run_analysis`. + This method should not have side-effects on the analysis class itself + since it could potentially be called asynchronously in multiple threads. + Any configurable option values should be specified in the `_default_options` + class method. These values can be overriden by a user by calling the + `set_options` method or for a single-run can be specified by passing kwarg + options to the :meth:`run` method. """ + def __init__(self): + """Initialize the analysis object.""" + # Analysis options + self._options = self._default_options() + + # Store keys of non-default options + self._set_options = set() + + def config(self) -> AnalysisConfig: + """Return the config dataclass for this analysis""" + args = tuple(getattr(self, "__init_args__", OrderedDict()).values()) + kwargs = dict(getattr(self, "__init_kwargs__", OrderedDict())) + # Only store non-default valued options + options = dict((key, getattr(self._options, key)) for key in self._set_options) + return AnalysisConfig( + cls=type(self), + args=args, + kwargs=kwargs, + options=options, + ) + + @classmethod + def from_config(cls, config: Union[AnalysisConfig, Dict]) -> "BaseAnalysis": + """Initialize an analysis class from analysis config""" + if isinstance(config, dict): + config = AnalysisConfig(**dict) + ret = cls(*config.args, **config.kwargs) + if config.options: + ret.set_options(**config.options) + return ret + + def copy(self) -> "BaseAnalysis": + """Return a copy of the analysis""" + # We want to avoid a deep copy be default for performance so we + # need to also copy the Options structures so that if they are + # updated on the copy they don't effect the original. + ret = copy.copy(self) + ret._options = copy.copy(self._options) + ret._set_options = copy.copy(self._set_options) + return ret + @classmethod def _default_options(cls) -> Options: return Options() + @property + def options(self) -> Options: + """Return the analysis options for :meth:`run` method.""" + return self._options + + def set_options(self, **fields): + """Set the analysis options for :meth:`run` method. + + Args: + fields: The fields to update the options + """ + self._options.update_options(**fields) + self._set_options = self._set_options.union(fields) + def run( self, experiment_data: ExperimentData, @@ -96,16 +157,20 @@ def run( else: experiment_components = [] - # Get analysis options - analysis_options = self._default_options() - analysis_options.update_options(**options) - analysis_options = analysis_options.__dict__ + # Set Analysis options + if not options: + analysis = self + else: + analysis = self.copy() + analysis.set_options(**options) def run_analysis(expdata): - results, figures = self._run_analysis(expdata, **analysis_options) + results, figures = analysis._run_analysis(expdata) # Add components analysis_results = [ - self._format_analysis_result(result, expdata.experiment_id, experiment_components) + analysis._format_analysis_result( + result, expdata.experiment_id, experiment_components + ) for result in results ] # Update experiment data with analysis results @@ -139,15 +204,13 @@ def _format_analysis_result(self, data, experiment_id, experiment_components=Non @abstractmethod def _run_analysis( - self, experiment_data: ExperimentData, **options + self, + experiment_data: ExperimentData, ) -> Tuple[List[AnalysisResultData], List["matplotlib.figure.Figure"]]: """Run analysis on circuit data. Args: experiment_data: the experiment data to analyze. - options: additional options for analysis. By default the fields and - values in :meth:`options` are used and any provided values - can override these. Returns: A pair ``(analysis_results, figures)`` where ``analysis_results`` @@ -157,4 +220,12 @@ def _run_analysis( Raises: AnalysisError: if the analysis fails. """ + # NOTE: passing kwarg options to _run_analysis should be removed once pass + + def __json_encode__(self): + return self.config() + + @classmethod + def __json_decode__(cls, value): + return cls.from_config(value) diff --git a/qiskit_experiments/framework/base_experiment.py b/qiskit_experiments/framework/base_experiment.py index b79ed0ecf3..c5ffeeff73 100644 --- a/qiskit_experiments/framework/base_experiment.py +++ b/qiskit_experiments/framework/base_experiment.py @@ -15,7 +15,6 @@ from abc import ABC, abstractmethod import copy -import dataclasses from collections import OrderedDict from typing import Sequence, Optional, Tuple, List, Dict, Union, Any @@ -28,60 +27,7 @@ from qiskit.providers.options import Options from qiskit_experiments.framework.store_init_args import StoreInitArgs from qiskit_experiments.framework.experiment_data import ExperimentData -from qiskit_experiments.version import __version__ - - -@dataclasses.dataclass(frozen=True) -class ExperimentConfig: - """Store configuration settings for an Experiment - - This stores the current configuration of a - :class:~qiskit_experiments.framework.BaseExperiment` and - can be used to reconstruct the experiment using either the - :meth:`experiment` property if the experiment class type is - currently stored, or the - :meth:~qiskit_experiments.framework.BaseExperiment.from_config` - class method of the appropriate experiment. - """ - - cls: type = None - args: Tuple[Any] = dataclasses.field(default_factory=tuple) - kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) - experiment_options: Dict[str, Any] = dataclasses.field(default_factory=dict) - transpile_options: Dict[str, Any] = dataclasses.field(default_factory=dict) - run_options: Dict[str, Any] = dataclasses.field(default_factory=dict) - version: str = __version__ - - def experiment(self) -> "BaseExperiment": - """Return the experiment constructed from this config. - - Returns: - The experiment reconstructed from the config. - - Raises: - QiskitError: if the experiment class is not stored, - was not successful deserialized, or reconstruction - of the experiment fails. - """ - cls = self.cls - if cls is None: - raise QiskitError("No experiment class in experiment config") - if isinstance(cls, dict): - raise QiskitError( - "Unable to load experiment class. Try manually loading " - "experiment using `Experiment.from_config(config)` instead." - ) - try: - return cls.from_config(self) - except Exception as ex: - msg = "Unable to construct experiments from config." - if cls.version != __version__: - msg += ( - f" Note that config version ({cls.version}) differs from the current" - f" qiskit-experiments version ({__version__}). You could try" - " installing a compatible qiskit-experiments version." - ) - raise QiskitError("{}\nError Message:\n{}".format(msg, str(ex))) from ex +from qiskit_experiments.framework.configs import ExperimentConfig class BaseExperiment(ABC, StoreInitArgs): @@ -184,6 +130,11 @@ def copy(self) -> "BaseExperiment": ret._run_options = copy.copy(self._run_options) ret._transpile_options = copy.copy(self._transpile_options) ret._analysis_options = copy.copy(self._analysis_options) + + ret._set_experiment_options = copy.copy(self._set_experiment_options) + ret._set_transpile_options = copy.copy(self._set_transpile_options) + ret._set_run_options = copy.copy(self._set_run_options) + ret._set_analysis_options = copy.copy(self._set_analysis_options) return ret def config(self) -> ExperimentConfig: diff --git a/qiskit_experiments/framework/composite/composite_analysis.py b/qiskit_experiments/framework/composite/composite_analysis.py index 6f27489f9d..fc4f53f929 100644 --- a/qiskit_experiments/framework/composite/composite_analysis.py +++ b/qiskit_experiments/framework/composite/composite_analysis.py @@ -45,23 +45,7 @@ class CompositeAnalysis(BaseAnalysis): reconstructed from the parent composite experiment data. """ - # pylint: disable = arguments-differ - def _run_analysis(self, experiment_data: ExperimentData, **options): - """Run analysis on composite experiment circuit data. - - Args: - experiment_data: the experiment data to analyze. - options: kwarg options for analysis function. - - Returns: - tuple: A pair ``(analysis_results, figures)`` where ``analysis_results`` - is a list of :class:`AnalysisResultData` objects, and ``figures`` - is a list of any figures for the experiment. - - Raises: - QiskitError: if analysis is attempted on non-composite - experiment data. - """ + def _run_analysis(self, experiment_data: ExperimentData): # Extract job metadata for the component experiments so it can be added # to the child experiment data incase it is required by the child experiments # analysis classes diff --git a/qiskit_experiments/framework/configs.py b/qiskit_experiments/framework/configs.py new file mode 100644 index 0000000000..07e84c037f --- /dev/null +++ b/qiskit_experiments/framework/configs.py @@ -0,0 +1,124 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2021. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +""" +Experiment and analysis config dataclasses. +""" + +import dataclasses +from typing import Tuple, Dict, Any + +from qiskit.exceptions import QiskitError +from qiskit_experiments.version import __version__ + + +@dataclasses.dataclass(frozen=True) +class ExperimentConfig: + """Store configuration settings for an Experiment + + This stores the current configuration of a + :class:~qiskit_experiments.framework.BaseExperiment` and + can be used to reconstruct the experiment using either the + :meth:`experiment` property if the experiment class type is + currently stored, or the + :meth:~qiskit_experiments.framework.BaseExperiment.from_config` + class method of the appropriate experiment. + """ + + cls: type = None + args: Tuple[Any] = dataclasses.field(default_factory=tuple) + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + experiment_options: Dict[str, Any] = dataclasses.field(default_factory=dict) + transpile_options: Dict[str, Any] = dataclasses.field(default_factory=dict) + run_options: Dict[str, Any] = dataclasses.field(default_factory=dict) + version: str = __version__ + + def experiment(self): + """Return the experiment constructed from this config. + + Returns: + BaseExperiment: The experiment reconstructed from the config. + + Raises: + QiskitError: if the experiment class is not stored, + was not successful deserialized, or reconstruction + of the experiment fails. + """ + cls = self.cls + if cls is None: + raise QiskitError("No experiment class in experiment config") + if isinstance(cls, dict): + raise QiskitError( + "Unable to load experiment class. Try manually loading " + "experiment using `Experiment.from_config(config)` instead." + ) + try: + return cls.from_config(self) + except Exception as ex: + msg = "Unable to construct experiments from config." + if cls.version != __version__: + msg += ( + f" Note that config version ({cls.version}) differs from the current" + f" qiskit-experiments version ({__version__}). You could try" + " installing a compatible qiskit-experiments version." + ) + raise QiskitError("{}\nError Message:\n{}".format(msg, str(ex))) from ex + + +@dataclasses.dataclass(frozen=True) +class AnalysisConfig: + """Store configuration settings for Analysis + + This stores the current configuration of a + :class:~qiskit_experiments.framework.BaseAnalysis` and + can be used to reconstruct the analysis class using either the + :meth:`analysis` property if the analysis class type is + currently stored, or the + :meth:~qiskit_experiments.framework.BaseAnalysis.from_config` + class method of the appropriate experiment. + """ + + cls: type = None + args: Tuple[Any] = dataclasses.field(default_factory=tuple) + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + options: Dict[str, Any] = dataclasses.field(default_factory=dict) + version: str = __version__ + + def analysis(self): + """Return the analysis class constructed from this config. + + Returns: + BaseAnalysis: The analysis reconstructed from the config. + + Raises: + QiskitError: if the analysis class is not stored, + was not successful deserialized, or reconstruction + of the analysis class fails. + """ + cls = self.cls + if cls is None: + raise QiskitError("No analysis class in analysis config") + if isinstance(cls, dict): + raise QiskitError( + "Unable to load analysis class. Try manually loading " + "analysis using `Analysis.from_config(config)` instead." + ) + try: + return cls.from_config(self) + except Exception as ex: + msg = "Unable to construct analysis from config." + if cls.version != __version__: + msg += ( + f" Note that config version ({cls.version}) differs from the current" + f" qiskit-experiments version ({__version__}). You could try" + " installing a compatible qiskit-experiments version." + ) + raise QiskitError("{}\nError Message:\n{}".format(msg, str(ex))) from ex diff --git a/qiskit_experiments/library/characterization/analysis/readout_angle_analysis.py b/qiskit_experiments/library/characterization/analysis/readout_angle_analysis.py index b850ed1f49..503d9e8529 100644 --- a/qiskit_experiments/library/characterization/analysis/readout_angle_analysis.py +++ b/qiskit_experiments/library/characterization/analysis/readout_angle_analysis.py @@ -23,8 +23,7 @@ class ReadoutAngleAnalysis(BaseAnalysis): A class to analyze readout angle experiments """ - # pylint: disable=unused-argument - def _run_analysis(self, experiment_data, **kwargs): + def _run_analysis(self, experiment_data): angles = [] for i in range(2): center = complex(*experiment_data.data(i)["memory"][0]) diff --git a/qiskit_experiments/library/quantum_volume/qv_analysis.py b/qiskit_experiments/library/quantum_volume/qv_analysis.py index f76bf44457..f5a41e3f2d 100644 --- a/qiskit_experiments/library/quantum_volume/qv_analysis.py +++ b/qiskit_experiments/library/quantum_volume/qv_analysis.py @@ -19,7 +19,7 @@ from typing import Optional import numpy as np -from qiskit_experiments.framework import BaseAnalysis, AnalysisResultData, FitVal +from qiskit_experiments.framework import BaseAnalysis, AnalysisResultData, Options, FitVal from qiskit_experiments.curve_analysis import plot_scatter, plot_errorbar @@ -37,26 +37,20 @@ class QuantumVolumeAnalysis(BaseAnalysis): is the success probability. """ - # pylint: disable = arguments-differ - def _run_analysis( - self, - experiment_data, - plot: bool = True, - ax: Optional["matplotlib.pyplot.AxesSubplot"] = None, - ): - """Run analysis on circuit data. - - Args: - experiment_data (ExperimentData): the experiment data to analyze. - plot (bool): If True generate a plot of fitted data. - ax (AxesSubplot): Optional, matplotlib axis to add plot to. + @classmethod + def _default_options(cls) -> Options: + """Return default analysis options. - Returns: - tuple: A pair ``(result_data figures)`` where - ``result_data`` is a list of - :class:`AnalysisResultData` objects, and ``figures`` may be - None, a single figure, or a list of figures. + Analysis Options: + plot (bool): Set ``True`` to create figure for fit result. + ax(AxesSubplot): Optional. A matplotlib axis object to draw. """ + options = super()._default_options() + options.plot = True + options.ax = None + return options + + def _run_analysis(self, experiment_data): depth = experiment_data.experiment.num_qubits data = experiment_data.data() num_trials = len(data) @@ -72,8 +66,8 @@ def _run_analysis( hop_result, qv_result = self._calc_quantum_volume(heavy_output_prob_exp, depth, num_trials) - if plot: - ax = self._format_plot(hop_result, ax=ax) + if self.options.plot: + ax = self._format_plot(hop_result, ax=self.options.ax) figures = [ax.get_figure()] else: figures = None diff --git a/qiskit_experiments/library/tomography/tomography_analysis.py b/qiskit_experiments/library/tomography/tomography_analysis.py index 5822132d51..f5d95c7bf1 100644 --- a/qiskit_experiments/library/tomography/tomography_analysis.py +++ b/qiskit_experiments/library/tomography/tomography_analysis.py @@ -68,6 +68,9 @@ def _default_options(cls) -> Options: This can be a string to select one of the built-in fitters, or a callable to supply a custom fitter function. See the `Fitter Functions` section for additional information. + fitter_options (dict): Any addition kwarg options to be supplied to the fitter + function. For documentation of available kargs refer to the fitter function + documentation. rescale_positive (bool): If True rescale the state returned by the fitter to be positive-semidefinite. See the `PSD Rescaling` section for additional information (Default: True). @@ -75,19 +78,16 @@ def _default_options(cls) -> Options: have either trace 1 for :class:`~qiskit.quantum_info.DensityMatrix`, or trace dim for :class:`~qiskit.quantum_info.Choi` matrices (Default: True). target (Any): depends on subclass. - kwargs: will be supplied to the fitter function, for documentation of available - args refer to the fitter function documentation. - """ options = super()._default_options() options.measurement_basis = None options.preparation_basis = None options.fitter = "linear_inversion" + options.fitter_options = {} options.rescale_positive = True options.rescale_trace = True options.target = "default" - return options @classmethod @@ -101,26 +101,20 @@ def _get_fitter(cls, fitter): return cls._builtin_fitters[fitter] raise AnalysisError(f"Unrecognized tomography fitter {fitter}") - def _run_analysis(self, experiment_data, **options): + def _run_analysis(self, experiment_data): # Extract tomography measurement data outcome_data, shot_data, measurement_data, preparation_data = self._fitter_data( experiment_data.data() ) - # Get tomography options - measurement_basis = options.pop("measurement_basis") - preparation_basis = options.pop("preparation_basis", None) - rescale_positive = options.pop("rescale_positive") - rescale_trace = options.pop("rescale_trace") - target_state = options.pop("target") - - # Get target state from circuit metadata + # Get target state + target_state = self.options.target if target_state == "default": metadata = experiment_data.metadata target_state = metadata.get("target", None) # Get tomography fitter function - fitter = self._get_fitter(options.pop("fitter", None)) + fitter = self._get_fitter(self.options.fitter) try: t_fitter_start = time.time() state, fitter_metadata = fitter( @@ -128,14 +122,14 @@ def _run_analysis(self, experiment_data, **options): shot_data, measurement_data, preparation_data, - measurement_basis, - preparation_basis, - **options, + self.options.measurement_basis, + self.options.preparation_basis, + **self.options.fitter_options, ) t_fitter_stop = time.time() if fitter_metadata is None: fitter_metadata = {} - state = Choi(state) if preparation_basis else DensityMatrix(state) + state = Choi(state) if self.options.preparation_basis else DensityMatrix(state) fitter_metadata["fitter"] = fitter.__name__ fitter_metadata["fitter_time"] = t_fitter_stop - t_fitter_start @@ -143,9 +137,9 @@ def _run_analysis(self, experiment_data, **options): state, metadata=fitter_metadata, target_state=target_state, - rescale_positive=rescale_positive, - rescale_trace=rescale_trace, - qpt=bool(preparation_basis), + rescale_positive=self.options.rescale_positive, + rescale_trace=self.options.rescale_trace, + qpt=bool(self.options.preparation_basis), ) except AnalysisError as ex: diff --git a/test/curve_analysis/test_curve_fit.py b/test/curve_analysis/test_curve_fit.py index 4133a9ba03..70ad3fb8b4 100644 --- a/test/curve_analysis/test_curve_fit.py +++ b/test/curve_analysis/test_curve_fit.py @@ -318,11 +318,12 @@ def test_run_single_curve_analysis(self): xvals=self.xvalues, param_dict={"amp": ref_p0, "lamb": ref_p1, "x0": ref_p2, "baseline": ref_p3}, ) - default_opts = analysis._default_options() - default_opts.p0 = {"p0": ref_p0, "p1": ref_p1, "p2": ref_p2, "p3": ref_p3} - default_opts.result_parameters = [ParameterRepr("p1", "parameter_name", "unit")] + analysis.set_options( + p0={"p0": ref_p0, "p1": ref_p1, "p2": ref_p2, "p3": ref_p3}, + result_parameters=[ParameterRepr("p1", "parameter_name", "unit")], + ) - results, _ = analysis._run_analysis(test_data, **default_opts.__dict__) + results, _ = analysis._run_analysis(test_data) result = results[0] ref_popt = np.asarray([ref_p0, ref_p1, ref_p2, ref_p3]) @@ -361,13 +362,14 @@ def test_run_single_curve_fail(self): xvals=self.xvalues, param_dict={"amp": ref_p0, "lamb": ref_p1, "x0": ref_p2, "baseline": ref_p3}, ) - default_opts = analysis._default_options() - default_opts.p0 = {"p0": ref_p0, "p1": ref_p1, "p2": ref_p2, "p3": ref_p3} - default_opts.bounds = {"p0": [-10, 0], "p1": [-10, 0], "p2": [-10, 0], "p3": [-10, 0]} - default_opts.return_data_points = True + analysis.set_options( + p0={"p0": ref_p0, "p1": ref_p1, "p2": ref_p2, "p3": ref_p3}, + bounds={"p0": [-10, 0], "p1": [-10, 0], "p2": [-10, 0], "p3": [-10, 0]}, + return_data_points=True, + ) # Try to fit with infeasible parameter boundary. This should fail. - results, _ = analysis._run_analysis(test_data, **default_opts.__dict__) + results, _ = analysis._run_analysis(test_data) # This returns only data point entry self.assertEqual(len(results), 1) @@ -417,10 +419,10 @@ def test_run_two_curves_with_same_fitfunc(self): for datum in test_data1.data(): test_data0.add_data(datum) - default_opts = analysis._default_options() - default_opts.p0 = {"p0": ref_p0, "p1": ref_p1, "p2": ref_p2, "p3": ref_p3, "p4": ref_p4} - - results, _ = analysis._run_analysis(test_data0, **default_opts.__dict__) + analysis.set_options( + p0={"p0": ref_p0, "p1": ref_p1, "p2": ref_p2, "p3": ref_p3, "p4": ref_p4} + ) + results, _ = analysis._run_analysis(test_data0) result = results[0] ref_popt = np.asarray([ref_p0, ref_p1, ref_p2, ref_p3, ref_p4]) @@ -471,10 +473,8 @@ def test_run_two_curves_with_two_fitfuncs(self): for datum in test_data1.data(): test_data0.add_data(datum) - default_opts = analysis._default_options() - default_opts.p0 = {"p0": ref_p0, "p1": ref_p1, "p2": ref_p2, "p3": ref_p3} - - results, _ = analysis._run_analysis(test_data0, **default_opts.__dict__) + analysis.set_options(p0={"p0": ref_p0, "p1": ref_p1, "p2": ref_p2, "p3": ref_p3}) + results, _ = analysis._run_analysis(test_data0) result = results[0] ref_popt = np.asarray([ref_p0, ref_p1, ref_p2, ref_p3]) @@ -507,11 +507,12 @@ def test_run_fixed_parameters(self): param_dict={"amp": ref_p0, "freq": ref_p1, "phase": ref_p2, "baseline": ref_p3}, ) - default_opts = analysis._default_options() - default_opts.p0 = {"p0": ref_p0, "p1": ref_p1, "p3": ref_p3} - default_opts.fixed_p2 = ref_p2 + analysis.set_options( + p0={"p0": ref_p0, "p1": ref_p1, "p3": ref_p3}, + fixed_p2=ref_p2, + ) - results, _ = analysis._run_analysis(test_data, **default_opts.__dict__) + results, _ = analysis._run_analysis(test_data) result = results[0] ref_popt = np.asarray([ref_p0, ref_p1, ref_p3]) @@ -543,14 +544,10 @@ def test_fixed_param_is_missing(self): xvals=self.xvalues, param_dict={"amp": ref_p0, "freq": ref_p1, "phase": ref_p2, "baseline": ref_p3}, ) - - default_opts = analysis._default_options() - # do not define fixed_p2 here - default_opts.p0 = {"p0": ref_p0, "p1": ref_p1, "p3": ref_p3} - + analysis.set_options(p0={"p0": ref_p0, "p1": ref_p1, "p3": ref_p3}) with self.assertRaises(AnalysisError): - analysis._run_analysis(test_data, **default_opts.__dict__) + analysis._run_analysis(test_data) class TestFitOptions(QiskitExperimentsTestCase): diff --git a/test/fake_experiment.py b/test/fake_experiment.py index a081959b00..fc2dc55416 100644 --- a/test/fake_experiment.py +++ b/test/fake_experiment.py @@ -21,6 +21,10 @@ class FakeAnalysis(BaseAnalysis): Dummy analysis class for test purposes only. """ + def __init__(self, **kwargs): + super().__init__() + self._kwargs = kwargs + def _run_analysis(self, experiment_data, **options): seed = options.get("seed", None) rng = np.random.default_rng(seed=seed) diff --git a/test/test_framework.py b/test/test_framework.py index 7da0f2b2ea..488dd655a9 100644 --- a/test/test_framework.py +++ b/test/test_framework.py @@ -78,3 +78,29 @@ def test_analysis_replace_results_false(self): self.assertNotEqual(expdata1, expdata2) self.assertNotEqual(expdata1.experiment_id, expdata2.experiment_id) self.assertNotEqual(expdata1.analysis_results(), expdata2.analysis_results()) + + def test_analysis_config(self): + """Test analysis config dataclass""" + analysis = FakeAnalysis(arg1=10, arg2=20) + analysis.set_options(option1=False, option2=True) + config = analysis.config() + loaded = config.analysis() + self.assertEqual(analysis.config(), loaded.config()) + self.assertEqual(analysis.options, loaded.options) + + def test_analysis_from_config(self): + """Test analysis config dataclass""" + analysis = FakeAnalysis(arg1=10, arg2=20) + analysis.set_options(option1=False, option2=True) + config = analysis.config() + loaded = FakeAnalysis.from_config(config) + self.assertEqual(config, loaded.config()) + + def test_analysis_runtime_opts(self): + """Test runtime options don't modify instance""" + opts = {"opt1": False, "opt2": False} + run_opts = {"opt1": True, "opt2": True, "opt3": True} + analysis = FakeAnalysis() + analysis.set_options(**opts) + analysis.run(ExperimentData(), **run_opts) + self.assertEqual(analysis.options.__dict__, opts) diff --git a/test/test_t1.py b/test/test_t1.py index ae79e6a87e..e05ed0190d 100644 --- a/test/test_t1.py +++ b/test/test_t1.py @@ -214,7 +214,7 @@ def test_t1_low_quality(self): self.assertEqual(result.quality, "bad") def test_experiment_config(self): - """Test converting to and from config works""" + """Test converting experiment to and from config works""" exp = T1(0, [1, 2, 3, 4, 5], unit="s") loaded_exp = T1.from_config(exp.config()) self.assertNotEqual(exp, loaded_exp) @@ -224,3 +224,10 @@ def test_roundtrip_serializable(self): """Test round trip JSON serialization""" exp = T1(0, [1, 2, 3, 4, 5], unit="s") self.assertRoundTripSerializable(exp, self.experiments_equiv) + + def test_analysis_config(self): + """ "Test converting analysis to and from config works""" + analysis = T1Analysis() + loaded = T1Analysis.from_config(analysis.config()) + self.assertNotEqual(analysis, loaded) + self.assertEqual(analysis.config(), loaded.config()) diff --git a/test/test_t2ramsey.py b/test/test_t2ramsey.py index 373eb8eb1b..be764e3a2f 100644 --- a/test/test_t2ramsey.py +++ b/test/test_t2ramsey.py @@ -18,6 +18,7 @@ from qiskit.utils import apply_prefix from qiskit_experiments.framework import ParallelExperiment from qiskit_experiments.library import T2Ramsey +from qiskit_experiments.library.characterization import T2RamseyAnalysis from qiskit_experiments.test.t2ramsey_backend import T2RamseyBackend @@ -210,3 +211,10 @@ def test_roundtrip_serializable(self): """Test round trip JSON serialization""" exp = T2Ramsey(0, [1, 2, 3, 4, 5], unit="s") self.assertRoundTripSerializable(exp, self.experiments_equiv) + + def test_analysis_config(self): + """ "Test converting analysis to and from config works""" + analysis = T2RamseyAnalysis() + loaded = T2RamseyAnalysis.from_config(analysis.config()) + self.assertNotEqual(analysis, loaded) + self.assertEqual(analysis.config(), loaded.config()) diff --git a/test/test_tomography.py b/test/test_tomography.py index 4f16f66b2f..a903d02c7e 100644 --- a/test/test_tomography.py +++ b/test/test_tomography.py @@ -21,6 +21,7 @@ from qiskit.providers.aer import AerSimulator from qiskit_experiments.framework import BatchExperiment, ParallelExperiment from qiskit_experiments.library import StateTomography, ProcessTomography +from qiskit_experiments.library.tomography import StateTomographyAnalysis, ProcessTomographyAnalysis # TODO: tests for CVXPY fitters @@ -273,6 +274,13 @@ def test_experiment_config(self): self.assertNotEqual(exp, loaded_exp) self.assertTrue(self.experiments_equiv(exp, loaded_exp)) + def test_analysis_config(self): + """ "Test converting analysis to and from config works""" + analysis = StateTomographyAnalysis() + loaded = StateTomographyAnalysis.from_config(analysis.config()) + self.assertNotEqual(analysis, loaded) + self.assertEqual(analysis.config(), loaded.config()) + @ddt.ddt class TestProcessTomography(QiskitExperimentsTestCase): @@ -485,6 +493,13 @@ def test_experiment_config(self): self.assertNotEqual(exp, loaded_exp) self.assertTrue(self.experiments_equiv(exp, loaded_exp)) + def test_analysis_config(self): + """ "Test converting analysis to and from config works""" + analysis = ProcessTomographyAnalysis() + loaded = ProcessTomographyAnalysis.from_config(analysis.config()) + self.assertNotEqual(analysis, loaded) + self.assertEqual(analysis.config(), loaded.config()) + def teleport_circuit(): """Teleport qubit 0 to qubit 2"""