From 41b4ce4c8c6340bfe599446732593054dd7358be Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Wed, 13 Mar 2024 19:22:38 +0100 Subject: [PATCH] Optionally include measurements in `plot_observable_trajectories` (#2381) If some `ExpData` is provided, `plot_observable_trajectories` will now also visualize the measurements. --- python/sdist/amici/plotting.py | 46 ++++++++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 10 deletions(-) diff --git a/python/sdist/amici/plotting.py b/python/sdist/amici/plotting.py index 25607638d7..d27f2994ce 100644 --- a/python/sdist/amici/plotting.py +++ b/python/sdist/amici/plotting.py @@ -12,6 +12,7 @@ import seaborn as sns from matplotlib.axes import Axes +import amici from . import Model, ReturnDataView from .numpy import StrOrExpr, evaluate @@ -66,10 +67,11 @@ def plot_state_trajectories( for ix, label in zip(state_indices, labels): ax.plot(rdata["t"], rdata["x"][:, ix], marker=marker, label=label) - ax.set_xlabel("$t$") - ax.set_ylabel("$x(t)$") - ax.legend() - ax.set_title("State trajectories") + + ax.set_xlabel("$t$") + ax.set_ylabel("$x(t)$") + ax.legend() + ax.set_title("State trajectories") def plot_observable_trajectories( @@ -79,6 +81,7 @@ def plot_observable_trajectories( model: Model = None, prefer_names: bool = True, marker=None, + edata: Union[amici.ExpData, amici.ExpDataView] = None, ) -> None: """ Plot observable trajectories. @@ -97,8 +100,12 @@ def plot_observable_trajectories( :param marker: Point marker for plotting (see `matplotlib documentation `_). - + :param edata: + Experimental data to be plotted (no event observables yet). """ + if isinstance(edata, amici.amici.ExpData): + edata = amici.ExpDataView(edata) + if not ax: fig, ax = plt.subplots() if not observable_indices: @@ -125,11 +132,30 @@ def plot_observable_trajectories( labels = np.asarray(rdata.ptr.observable_ids)[list(observable_indices)] for iy, label in zip(observable_indices, labels): - ax.plot(rdata["t"], rdata["y"][:, iy], marker=marker, label=label) - ax.set_xlabel("$t$") - ax.set_ylabel("$y(t)$") - ax.legend() - ax.set_title("Observable trajectories") + (l,) = ax.plot( + rdata["t"], rdata["y"][:, iy], marker=marker, label=label + ) + + if edata is not None: + ax.plot( + edata.ts, + edata.observedData[:, iy], + "x", + label=f"exp. {label}", + color=l.get_color(), + ) + ax.errorbar( + edata.ts, + edata.observedData[:, iy], + yerr=rdata.sigmay[:, iy], + fmt="none", + color=l.get_color(), + ) + + ax.set_xlabel("$t$") + ax.set_ylabel("$y(t)$") + ax.set_title("Observable trajectories") + ax.legend() def plot_jacobian(rdata: ReturnDataView):