Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 27 additions & 99 deletions qiskit_experiments/curve_analysis/curve_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,14 +250,6 @@ def __init__(self):
#: Backend: backend object used for experimentation
self.__backend = None

# Add expected options to instance variable so that every method can access to.
for key in self._default_options().__dict__:
setattr(self, f"__{key}", None)

# Add fixed parameters to instance variable so that every method can access to.
for key in self.__fixed_parameters__:
setattr(self, f"__{key}", None)

@classmethod
def _fit_params(cls) -> List[str]:
"""Return a list of fitting parameters.
Expand Down Expand Up @@ -346,6 +338,8 @@ def _default_options(cls) -> Options:
that contains a set of configurations to create a fit plot.
extra (Dict[str, Any]): A dictionary that is appended to all database entries
as extra information.
curve_fitter_options (Dict[str, Any]) Options that are passed to the
specified curve fitting function.
"""
options = super()._default_options()

Expand All @@ -366,6 +360,7 @@ def _default_options(cls) -> Options:
options.curve_plotter = "mpl_single_canvas"
options.style = PlotterStyle()
options.extra = dict()
options.curve_fitter_options = dict()

# automatically populate initial guess and boundary
fit_params = cls._fit_params()
Expand Down Expand Up @@ -577,7 +572,7 @@ def _is_target_series(datum, **filters):
# Extract X, Y, Y_sigma data
data = experiment_data.data()

x_key = self._get_option("x_key")
x_key = self.options.x_key
try:
x_values = np.asarray([datum["metadata"][x_key] for datum in data], dtype=float)
except KeyError as ex:
Expand Down Expand Up @@ -675,21 +670,6 @@ def _experiment_options(self, index: int = -1) -> Dict[str, Any]:
# Ignore experiment metadata or job metadata is not set or key is not found
return None

def _analysis_options(self, index: int = -1) -> Dict[str, Any]:
"""Returns the analysis options of given job index.

Args:
index: Index of job metadata to extract. Default to -1 (latest).

Returns:
Analysis options. This option is used for analysis.
"""
try:
return self.__experiment_metadata["job_metadata"][index]["analysis_options"]
except (TypeError, KeyError, IndexError):
# Ignore experiment metadata or job metadata is not set or key is not found
return None

def _run_options(self, index: int = -1) -> Dict[str, Any]:
"""Returns the run options of given job index.

Expand Down Expand Up @@ -772,67 +752,17 @@ def _data(

raise AnalysisError(f"Specified series {series_name} is not defined in this analysis.")

def _arg_parse(self, **options) -> Dict[str, Any]:
"""Parse input kwargs with predicted input.

Class attributes will be updated according to the ``options``.
For example, if ``options`` has a key ``p0``, and the class
has an attribute named ``__p0``, then the attribute ``__0p``
will be updated to ``options["p0"]``.

Options that don't have matching attributes will be included
in the returned dictionary.

Args:
options: User-input keyword argument options.

Returns:
Keyword arguments not specified in the default options
of the class.
"""
extra_options = dict()
for key, value in options.items():
private_key = f"__{key}"
if hasattr(self, private_key):
setattr(self, private_key, value)
else:
extra_options[key] = value

return extra_options

def _get_option(self, arg_name: str) -> Any:
"""A helper function to get specified field from the input analysis options.

Args:
arg_name: Name of option.

Return:
Arbitrary object specified by the option name.

Raises:
AnalysisError:
- When `arg_name` is not found in the analysis options.
"""
try:
return getattr(self, f"__{arg_name}")
except AttributeError as ex:
raise AnalysisError(
f"The argument {arg_name} is selected but not defined. "
"This key-value pair should be defined in the analysis option."
) from ex

def _run_analysis(
self, experiment_data: ExperimentData
) -> Tuple[List[AnalysisResultData], List["pyplot.Figure"]]:
#
# 1. Parse arguments
#
extra_options = self._arg_parse(**self.options.__dict__)

# Update all fit functions in the series definitions if fixed parameter is defined.
# Fixed parameters should be provided by the analysis options.
if self.__fixed_parameters__:
assigned_params = {k: self._get_option(k) for k in self.__fixed_parameters__}
assigned_params = {k: self.options.get(k, None) for k in self.__fixed_parameters__}

# Check if all parameters are assigned.
if any(v is None for v in assigned_params.values()):
Expand Down Expand Up @@ -870,7 +800,7 @@ def _run_analysis(

# No data processor has been provided at run-time we infer one from the job
# metadata and default to the data processor for averaged classified data.
data_processor = self._get_option("data_processor")
data_processor = self.options.data_processor

if not data_processor:
run_options = self._run_options() or dict()
Expand All @@ -883,9 +813,8 @@ def _run_analysis(
) from ex

meas_return = run_options.get("meas_return", None)
normalization = self._get_option("normalization")

data_processor = get_processor(meas_level, meas_return, normalization)
data_processor = get_processor(meas_level, meas_return, self.options.normalization)

if isinstance(data_processor, DataProcessor) and not data_processor.is_trained:
# Qiskit DataProcessor instance. May need calibration.
Expand All @@ -899,15 +828,14 @@ def _run_analysis(
#
# 4. Run fitting
#
curve_fitter = self._get_option("curve_fitter")
formatted_data = self._data(label="fit_ready")

# Generate algorithmic initial guesses and boundaries
default_fit_opt = FitOptions(
parameters=self._fit_params(),
default_p0=self._get_option("p0"),
default_bounds=self._get_option("bounds"),
**extra_options,
default_p0=self.options.p0,
default_bounds=self.options.bounds,
**self.options.curve_fitter_options,
)

fit_options = self._generate_fit_guesses(default_fit_opt)
Expand All @@ -918,7 +846,7 @@ def _run_analysis(
fit_results = []
for fit_opt in set(fit_options):
try:
fit_result = curve_fitter(
fit_result = self.options.curve_fitter(
funcs=[series_def.fit_func for series_def in self.__series__],
series=formatted_data.data_index,
xdata=formatted_data.x,
Expand Down Expand Up @@ -969,13 +897,13 @@ def _run_analysis(
"dof": fit_result.dof,
"covariance_mat": fit_result.pcov,
"fit_models": fit_models,
**self._get_option("extra"),
**self.options.extra,
},
)
)

# output special parameters
result_parameters = self._get_option("result_parameters")
result_parameters = self.options.result_parameters
if result_parameters:
for param_repr in result_parameters:
if isinstance(param_repr, ParameterRepr):
Expand All @@ -991,14 +919,14 @@ def _run_analysis(
value=fit_result.fitval(p_name, unit),
chisq=fit_result.reduced_chisq,
quality=quality,
extra=self._get_option("extra"),
extra=self.options.extra,
)
analysis_results.append(result_entry)

# add extra database entries
analysis_results.extend(self._extra_database_entry(fit_result))

if self._get_option("return_data_points"):
if self.options.return_data_points:
# save raw data points in the data base if option is set (default to false)
raw_data_dict = dict()
for series_def in self.__series__:
Expand All @@ -1012,32 +940,32 @@ def _run_analysis(
name=DATA_ENTRY_PREFIX + self.__class__.__name__,
value=raw_data_dict,
extra={
"x-unit": self._get_option("xval_unit"),
"y-unit": self._get_option("yval_unit"),
"x-unit": self.options.xval_unit,
"y-unit": self.options.yval_unit,
},
)
analysis_results.append(raw_data_entry)

#
# 6. Create figures
#
if self._get_option("plot"):
fit_figure = FitResultPlotters[self._get_option("curve_plotter")].value.draw(
if self.options.plot:
fit_figure = FitResultPlotters[self.options.curve_plotter].value.draw(
series_defs=self.__series__,
raw_samples=[self._data(ser.name, "raw_data") for ser in self.__series__],
fit_samples=[self._data(ser.name, "fit_ready") for ser in self.__series__],
tick_labels={
"xval_unit": self._get_option("xval_unit"),
"yval_unit": self._get_option("yval_unit"),
"xlabel": self._get_option("xlabel"),
"ylabel": self._get_option("ylabel"),
"xlim": self._get_option("xlim"),
"ylim": self._get_option("ylim"),
"xval_unit": self.options.xval_unit,
"yval_unit": self.options.yval_unit,
"xlabel": self.options.xlabel,
"ylabel": self.options.ylabel,
"xlim": self.options.xlim,
"ylim": self.options.ylim,
},
fit_data=fit_result,
result_entries=analysis_results,
style=self._get_option("style"),
axis=self._get_option("axis"),
style=self.options.style,
axis=self.options.axis,
)
figures = [fit_figure]
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,14 @@ def _generate_fit_guesses(
user_opt.bounds.set_if_empty(amp=(0, 2 * max_abs_y))

# Base the initial guess on the intended angle_per_gate and phase offset.
apg = self._get_option("angle_per_gate")
phi = self._get_option("phase_offset")
apg = self.options.angle_per_gate
phi = self.options.phase_offset

# Prepare logical guess for specific condition (often satisfied)
d_theta_guesses = []

offsets = apg * curve_data.x + phi
amp = user_opt.p0.get("amp", self._get_option("amp"))
amp = user_opt.p0.get("amp", self.options.amp)
for i in range(curve_data.x.size):
xi = curve_data.x[i]
yi = curve_data.y[i]
Expand Down Expand Up @@ -192,11 +192,10 @@ def _evaluate_quality(self, fit_data: curve.FitData) -> Union[str, None]:
This quantity is set in the analysis options.
"""
fit_d_theta = fit_data.fitval("d_theta").value
max_good_angle_error = self._get_option("max_good_angle_error")

criteria = [
fit_data.reduced_chisq < 3,
abs(fit_d_theta) < abs(max_good_angle_error),
abs(fit_d_theta) < abs(self.options.max_good_angle_error),
]

if all(criteria):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,16 +181,16 @@ def _extra_database_entry(self, fit_data: curve.FitData) -> List[AnalysisResultD
)

# Calculate EPG
if not self._get_option("gate_error_ratio"):
if not self.options.gate_error_ratio:
# we attempt to get the ratio from the backend properties
if not self._get_option("error_dict"):
if not self.options.error_dict:
gate_error_ratio = RBUtils.get_error_dict_from_backend(
backend=self._backend, qubits=self._physical_qubits
)
else:
gate_error_ratio = self._get_option("error_dict")
gate_error_ratio = self.options.error_dict
else:
gate_error_ratio = self._get_option("gate_error_ratio")
gate_error_ratio = self.options.gate_error_ratio

count_ops = []
for meta in self._data(label="raw_data").metadata:
Expand All @@ -208,7 +208,7 @@ def _extra_database_entry(self, fit_data: curve.FitData) -> List[AnalysisResultD
gates_per_clifford,
)
elif num_qubits == 2:
epg_1_qubit = self._get_option("epg_1_qubit")
epg_1_qubit = self.options.epg_1_qubit
epg = RBUtils.calculate_2q_epg(
epc,
self._physical_qubits,
Expand Down
16 changes: 2 additions & 14 deletions test/curve_analysis/test_curve_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,21 +177,9 @@ def test_cannot_create_invalid_fixed_parameter(self):
fixed_params=["not_existing_parameter"], # this parameter is not defined
)

def test_arg_parse_and_get_option(self):
"""Test if option parsing works correctly."""
user_option = {"x_key": "test_value", "test_key1": "value1", "test_key2": "value2"}

# argument not defined in default option should be returned as extra option
extra_option = self.analysis._arg_parse(**user_option)
ref_option = {"test_key1": "value1", "test_key2": "value2"}
self.assertDictEqual(extra_option, ref_option)

# default option value is stored as class variable
self.assertEqual(self.analysis._get_option("x_key"), "test_value")

def test_data_extraction(self):
"""Test data extraction method."""
self.analysis._arg_parse(x_key="xval")
self.analysis.set_options(x_key="xval")

# data to analyze
test_data0 = simulate_output_data(
Expand Down Expand Up @@ -268,7 +256,7 @@ def test_get_subset(self):
def _processor(datum):
return datum["data"], datum["data"] * 2

self.analysis._arg_parse(x_key="xval")
self.analysis.set_options(x_key="xval")
self.analysis._extract_curves(expdata, data_processor=_processor)

filt_data = self.analysis._data(series_name="curve1")
Expand Down