Skip to content

Commit

Permalink
Optionally include measurements in plot_observable_trajectories (#2381
Browse files Browse the repository at this point in the history
)

If some `ExpData` is provided, `plot_observable_trajectories` will now also visualize the measurements.
  • Loading branch information
dweindl committed Mar 13, 2024
1 parent 032fb6c commit 41b4ce4
Showing 1 changed file with 36 additions and 10 deletions.
46 changes: 36 additions & 10 deletions python/sdist/amici/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import seaborn as sns
from matplotlib.axes import Axes

import amici
from . import Model, ReturnDataView
from .numpy import StrOrExpr, evaluate

Expand Down Expand Up @@ -66,10 +67,11 @@ def plot_state_trajectories(

for ix, label in zip(state_indices, labels):
ax.plot(rdata["t"], rdata["x"][:, ix], marker=marker, label=label)
ax.set_xlabel("$t$")
ax.set_ylabel("$x(t)$")
ax.legend()
ax.set_title("State trajectories")

ax.set_xlabel("$t$")
ax.set_ylabel("$x(t)$")
ax.legend()
ax.set_title("State trajectories")


def plot_observable_trajectories(
Expand All @@ -79,6 +81,7 @@ def plot_observable_trajectories(
model: Model = None,
prefer_names: bool = True,
marker=None,
edata: Union[amici.ExpData, amici.ExpDataView] = None,
) -> None:
"""
Plot observable trajectories.
Expand All @@ -97,8 +100,12 @@ def plot_observable_trajectories(
:param marker:
Point marker for plotting (see
`matplotlib documentation <https://matplotlib.org/stable/api/markers_api.html>`_).
:param edata:
Experimental data to be plotted (no event observables yet).
"""
if isinstance(edata, amici.amici.ExpData):
edata = amici.ExpDataView(edata)

if not ax:
fig, ax = plt.subplots()
if not observable_indices:
Expand All @@ -125,11 +132,30 @@ def plot_observable_trajectories(
labels = np.asarray(rdata.ptr.observable_ids)[list(observable_indices)]

for iy, label in zip(observable_indices, labels):
ax.plot(rdata["t"], rdata["y"][:, iy], marker=marker, label=label)
ax.set_xlabel("$t$")
ax.set_ylabel("$y(t)$")
ax.legend()
ax.set_title("Observable trajectories")
(l,) = ax.plot(
rdata["t"], rdata["y"][:, iy], marker=marker, label=label
)

if edata is not None:
ax.plot(
edata.ts,
edata.observedData[:, iy],
"x",
label=f"exp. {label}",
color=l.get_color(),
)
ax.errorbar(
edata.ts,
edata.observedData[:, iy],
yerr=rdata.sigmay[:, iy],
fmt="none",
color=l.get_color(),
)

ax.set_xlabel("$t$")
ax.set_ylabel("$y(t)$")
ax.set_title("Observable trajectories")
ax.legend()


def plot_jacobian(rdata: ReturnDataView):
Expand Down

0 comments on commit 41b4ce4

Please sign in to comment.