Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
397 changes: 397 additions & 0 deletions qiskit_experiments/curve_analysis/__init__.py

Large diffs are not rendered by default.

659 changes: 282 additions & 377 deletions qiskit_experiments/curve_analysis/curve_analysis.py

Large diffs are not rendered by default.

154 changes: 154 additions & 0 deletions qiskit_experiments/curve_analysis/curve_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""

import dataclasses
import inspect
from typing import Any, Dict, Callable, Union, List, Tuple, Optional, Iterable

import numpy as np
Expand Down Expand Up @@ -47,6 +48,149 @@ class SeriesDef:
# Index of canvas if the result figure is multi-panel
canvas: Optional[int] = None

# Automatically extracted signature of the fit function
signature: List[str] = dataclasses.field(init=False)

# Name of group. Curves in the same group are simultaneously fit.
group: Optional[str] = "default"

def __post_init__(self):
"""Implicitly parse fit function signature for fit function."""
# The first argument is x, which is not a fit parameter
sig = list(inspect.signature(self.fit_func).parameters.keys())[1:]
# Note that this dataclass is frozen
object.__setattr__(self, "signature", sig)


class CompositeFitFunction:
"""Function-like object that is generated by a curve analysis subclass.

This is function-like object that implements a fit model as a ``__call__`` magic method,
thus it behaves as if a python function that the SciPy curve_fit solver accepts.
Note that the fit function there only accepts variadic arguments.

This class ties together the fit function and associated parameter names to
perform correct parameter mapping among multiple objective functions with different signature,
in which some parameters may be excluded from the fitting when they are fixed.
"""

def __init__(
self,
group: str,
fit_functions: [List[Callable]],
signatures: List[List[str]],
curve_inds: List[int],
fixed_parameters: Optional[List[str]] = None,
**metadata,
):
"""Create new composite function.

Args:
group: A name of the fit group that this function belongs to.
fit_functions: List of callable that defines fit function of a single series.
signatures: List of parameter names of a single series.
curve_inds: List of index corresponding to the curve data.
fixed_parameters: List of parameter names that are fixed in the fit.
**metadata: Arbitrary dictionary with information of this fit function.

Raises:
AnalysisError: When ``fit_functions`` and ``signatures`` don't match.
"""
if len(fit_functions) != len(signatures):
raise AnalysisError("Different numbers of fit_functions and signatures are given.")

if fixed_parameters is None:
fixed_parameters = tuple()

self._group = group
self._fit_functions = fit_functions
self._signatures = signatures
self._curve_inds = curve_inds
self._metadata = metadata or dict()

# Parameters that can be overridden
self._fixed_params = {p: None for p in fixed_parameters}
self._data_index = None

fit_args = []
# Logic is not efficient but should keep order of parameters for backward compatibility
for signature in signatures:
for param in signature:
if param not in fit_args and param not in fixed_parameters:
fit_args.append(param)

self._full_params = fit_args

def __call__(self, x: np.ndarray, *params) -> np.ndarray:
"""Called by the scipy fit function.

Args:
x: Composite X values array.
*params: Variadic argument of fitting parameters.

Returns:
Computed Y values array.
"""
kwparams = dict(zip(self._full_params, params))
kwparams.update(self._fixed_params)

y = np.zeros(x.size)
for i, func, sig in zip(self._curve_inds, self._fit_functions, self._signatures):
if self._data_index is not None:
inds = self._data_index == i
else:
# Use all data if data index is not set
inds = np.full(x.size, True, dtype=bool)

y[inds] = func(x[inds], **{p: kwparams[p] for p in sig})

return y

def bind_parameters(self, **kwparams):
"""Set fixed parameters."""
bind_dict = {k: kwparams[k] for k in self._fixed_params.keys() if k in kwparams}
self._fixed_params.update(bind_dict)

@property
def data_index(self) -> np.ndarray:
"""Return current data index mapping."""
return self._data_index

@data_index.setter
def data_index(self, new_indices: np.ndarray):
"""Set data index mapping for current fit data."""
self._data_index = new_indices

@property
def signature(self) -> List[str]:
"""Return signature of the composite fit function."""
return self._full_params

@property
def metadata(self) -> Dict[str, Any]:
"""Return metadata of this fit function."""
return self._metadata

@property
def group(self) -> str:
"""Return a group that this function belongs to."""
return self._group

def copy(self):
"""Return copy of this function. Assigned parameters and indices are refleshed."""
return CompositeFitFunction(
group=self.group,
fit_functions=self._fit_functions,
signatures=self._signatures,
curve_inds=self._curve_inds,
fixed_parameters=list(self._fixed_params.keys()),
**self.metadata.copy(),
)

def __repr__(self):
sigrepr = ", ".join(self.signature)
return f"{self.__class__.__name__}(x, {sigrepr}; group={self.group})"


@dataclasses.dataclass(frozen=True)
class CurveData:
Expand Down Expand Up @@ -99,6 +243,12 @@ class FitData:
# Y data range
y_range: Tuple[float, float]

# String representation of fit model
fit_mdoel: str = "not defined"

# String representation of the group that this fit belongs to.
group: str = "default"

def fitval(self, key: str) -> uncertainties.UFloat:
"""A helper method to get fit value object from parameter key name.

Expand Down Expand Up @@ -278,11 +428,14 @@ class FitOptions:

def __init__(
self,
group: str,
parameters: List[str],
default_p0: Optional[Union[Iterable[float], Dict[str, float]]] = None,
default_bounds: Optional[Union[Iterable[Tuple], Dict[str, Tuple]]] = None,
**extra,
):
self.group = group

# These are private members so that user cannot directly override values
# without implicitly implemented validation logic. No setter will be provided.
self.__p0 = InitialGuesses(parameters, default_p0)
Expand All @@ -309,6 +462,7 @@ def add_extra_options(self, **kwargs):
def copy(self):
"""Create copy of this option."""
return FitOptions(
group=self.group,
parameters=list(self.__p0.keys()),
default_p0=dict(self.__p0),
default_bounds=dict(self.__bounds),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"""

from collections import defaultdict
import functools
from typing import List, Dict, Optional

import uncertainties
Expand All @@ -46,7 +47,8 @@ def draw(
raw_samples: List[CurveData],
fit_samples: List[CurveData],
tick_labels: Dict[str, str],
fit_data: FitData,
fit_data: List[FitData],
fix_parameters: Dict[str, float],
result_entries: List[AnalysisResultData],
style: Optional[PlotterStyle] = None,
axis: Optional["matplotlib.axes.Axes"] = None,
Expand All @@ -60,6 +62,7 @@ def draw(
tick_labels: Dictionary of axis label information. Axis units and label for x and y
value should be explained.
fit_data: fit data generated by the analysis.
fix_parameters: parameter not being in fitting.
result_entries: List of analysis result data entries.
style: Optional. A configuration object to modify the appearance of the figure.
axis: Optional. A matplotlib Axis object.
Expand All @@ -84,6 +87,7 @@ def draw(
raw_sample=raw_samp,
fit_sample=fit_samp,
fit_data=fit_data,
fix_parameters=fix_parameters,
style=style,
)

Expand Down Expand Up @@ -116,11 +120,20 @@ def draw(
# write analysis report
if fit_data:
report_str = write_fit_report(result_entries)
report_str += r"Fit $\chi^2$ = " + f"{fit_data.reduced_chisq: .4g}"

if len(fit_data) > 2:
chisq_strs = []
for fit_datum in fit_data:
chisq_strs.append(
r"Fit $\chi^2$ = " + f"{fit_datum.reduced_chisq: .4g} ({fit_datum.group})"
)
report_str += "\n".join(chisq_strs)
else:
report_str += r"Fit $\chi^2$ = " + f"{fit_data[0].reduced_chisq: .4g}"

report_handler = axis.text(
*style.fit_report_rpos,
report_str,
s=report_str,
ha="center",
va="top",
size=style.fit_report_text_size,
Expand All @@ -146,7 +159,8 @@ def draw(
raw_samples: List[CurveData],
fit_samples: List[CurveData],
tick_labels: Dict[str, str],
fit_data: FitData,
fit_data: List[FitData],
fix_parameters: Dict[str, float],
result_entries: List[AnalysisResultData],
style: Optional[PlotterStyle] = None,
axis: Optional["matplotlib.axes.Axes"] = None,
Expand All @@ -160,6 +174,7 @@ def draw(
tick_labels: Dictionary of axis label information. Axis units and label for x and y
value should be explained.
fit_data: fit data generated by the analysis.
fix_parameters: parameter not being in fitting.
result_entries: List of analysis result data entries.
style: Optional. A configuration object to modify the appearance of the figure.
axis: Optional. A matplotlib Axis object.
Expand Down Expand Up @@ -222,6 +237,7 @@ def draw(
raw_sample=raw_samples[curve_ind],
fit_sample=fit_samples[curve_ind],
fit_data=fit_data,
fix_parameters=fix_parameters,
style=style,
)

Expand Down Expand Up @@ -268,11 +284,20 @@ def draw(
# write analysis report
if fit_data:
report_str = write_fit_report(result_entries)
report_str += r"Fit $\chi^2$ = " + f"{fit_data.reduced_chisq: .4g}"

if len(fit_data) > 2:
chisq_strs = []
for fit_datum in fit_data:
chisq_strs.append(
r"Fit $\chi^2$ = " + f"{fit_datum.reduced_chisq: .4g} ({fit_datum.group})"
)
report_str += "\n".join(chisq_strs)
else:
report_str += r"Fit $\chi^2$ = " + f"{fit_data[0].reduced_chisq: .4g}"

report_handler = axis.text(
*style.fit_report_rpos,
report_str,
s=report_str,
ha="center",
va="top",
size=style.fit_report_text_size,
Expand All @@ -293,7 +318,8 @@ def draw_single_curve_mpl(
series_def: SeriesDef,
raw_sample: CurveData,
fit_sample: CurveData,
fit_data: FitData,
fit_data: List[FitData],
fix_parameters: Dict[str, float],
style: PlotterStyle,
):
"""A function that draws a single curve on the given plotter canvas.
Expand All @@ -304,6 +330,7 @@ def draw_single_curve_mpl(
raw_sample: Raw sample data.
fit_sample: Formatted sample data.
fit_data: Fitting parameter collection.
fix_parameters: Parameters not being in the fitting.
style: Style sheet for plotting.
"""

Expand All @@ -330,15 +357,21 @@ def draw_single_curve_mpl(
)

# plot fit curve
if fit_data:
plot_curve_fit(
func=series_def.fit_func,
result=fit_data,
ax=axis,
color=series_def.plot_color,
zorder=2,
fit_uncertainty=style.plot_sigma,
)
for fit_datum in fit_data:
if series_def.group == fit_datum.group:
if fix_parameters:
fit_func = functools.partial(series_def.fit_func, **fix_parameters)
else:
fit_func = series_def.fit_func

plot_curve_fit(
func=fit_func,
result=fit_datum,
ax=axis,
color=series_def.plot_color,
zorder=2,
fit_uncertainty=style.plot_sigma,
)


def write_fit_report(result_entries: List[AnalysisResultData]) -> str:
Expand Down
Loading