Skip to content

Commit

Permalink
Improve debugging info in ReturnData (#2349)
Browse files Browse the repository at this point in the history
* Improve debugging info in ReturnData

  * Add `ReturnData::t_last` that holds the last solver timepoint (Closes #2246)
  * `ReturnData::J` is now the evaluted at `t_last` instead of the last successfully reached output timepoint (Closes #2334)
  * `ReturnData::xdot` is now the evaluted at `t_last` instead of the last successfully reached output timepoint

* skip xdot/J checks in case of failures
  • Loading branch information
dweindl authored Mar 7, 2024
1 parent 2dfd2bb commit 8d65524
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 14 deletions.
2 changes: 1 addition & 1 deletion include/amici/forwardproblem.h
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ class ForwardProblem {
* @brief Creates a carbon copy of the current simulation state variables
* @return state
*/
SimulationState getSimulationState() const;
SimulationState getSimulationState();

/** array of index vectors (dimension ne) indicating whether the respective
* root has been detected for all so far encountered discontinuities,
Expand Down
7 changes: 5 additions & 2 deletions include/amici/rdata.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,12 @@ class ReturnData : public ModelDimensions {
*/
std::vector<realtype> ts;

/** time derivative (shape `nx`) */
/** time derivative (shape `nx`) evaluated at `t_last`. */
std::vector<realtype> xdot;

/**
* Jacobian of differential equation right hand side (shape `nx` x `nx`,
* row-major)
* row-major) evaluated at `t_last`.
*/
std::vector<realtype> J;

Expand Down Expand Up @@ -456,6 +456,9 @@ class ReturnData : public ModelDimensions {
/** log messages */
std::vector<LogItem> messages;

/** The final internal time of the solver. */
realtype t_last{std::numeric_limits<realtype>::quiet_NaN()};

protected:
/** offset for sigma_residuals */
realtype sigma_offset;
Expand Down
1 change: 1 addition & 0 deletions python/sdist/amici/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ class ReturnDataView(SwigPtrView):
"cpu_timeB",
"cpu_time_total",
"messages",
"t_last",
]

def __init__(self, rdata: Union[ReturnDataPtr, ReturnData]):
Expand Down
2 changes: 2 additions & 0 deletions src/amici.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,8 @@ std::unique_ptr<ReturnData> runAmiciSimulation(
);
}

rdata->t_last = solver.gett();

rdata->cpu_time_total = cpu_timer.elapsed_milliseconds();

// verify that reported CPU times are plausible
Expand Down
5 changes: 4 additions & 1 deletion src/forwardproblem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,10 @@ void ForwardProblem::getAdjointUpdates(Model& model, ExpData const& edata) {
}
}

SimulationState ForwardProblem::getSimulationState() const {
SimulationState ForwardProblem::getSimulationState() {
if (std::isfinite(solver->gett())) {
solver->writeSolution(&t_, x_, dx_, sx_, dx_);
}
auto state = SimulationState();
state.t = t_;
state.x = x_;
Expand Down
4 changes: 4 additions & 0 deletions src/hdf5.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,10 @@ void writeReturnDataDiagnosis(
&rdata.cpu_time_total, 1
);

H5LTset_attribute_double(
file.getId(), hdf5Location.c_str(), "t_last", &rdata.t_last, 1
);

if (!rdata.J.empty())
createAndWriteDouble2DDataset(
file, hdf5Location + "/J", rdata.J, rdata.nx, rdata.nx
Expand Down
23 changes: 13 additions & 10 deletions tests/cpp/testfunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,13 +192,6 @@ void verifyReturnData(std::string const& hdffile, std::string const& resultPath,

// CHECK_EQUAL(AMICI_O2MODE_FULL, udata->o2mode);

if(hdf5::locationExists(file, resultPath + "/diagnosis/J")) {
expected = hdf5::getDoubleDataset2D(file, resultPath + "/diagnosis/J", m, n);
checkEqualArray(expected, rdata->J, atol, rtol, "J");
} else {
ASSERT_TRUE(rdata->J.empty());
}

if(hdf5::locationExists(file, resultPath + "/y")) {
expected = hdf5::getDoubleDataset2D(file, resultPath + "/y", m, n);
checkEqualArray(expected, rdata->y, atol, rtol, "y");
Expand Down Expand Up @@ -234,12 +227,22 @@ void verifyReturnData(std::string const& hdffile, std::string const& resultPath,
ASSERT_TRUE(rdata->sigmaz.empty());
}

expected = hdf5::getDoubleDataset1D(file, resultPath + "/diagnosis/xdot");
checkEqualArray(expected, rdata->xdot, atol, rtol, "xdot");

expected = hdf5::getDoubleDataset1D(file, resultPath + "/x0");
checkEqualArray(expected, rdata->x0, atol, rtol, "x0");

if(rdata->status == AMICI_SUCCESS) {
// for the failed cases, the stored results don't match
// since https://github.com/AMICI-dev/AMICI/pull/2349
expected = hdf5::getDoubleDataset1D(file, resultPath + "/diagnosis/xdot");
checkEqualArray(expected, rdata->xdot, atol, rtol, "xdot");

if(hdf5::locationExists(file, resultPath + "/diagnosis/J")) {
expected = hdf5::getDoubleDataset2D(file, resultPath + "/diagnosis/J", m, n);
checkEqualArray(expected, rdata->J, atol, rtol, "J");
} else {
ASSERT_TRUE(rdata->J.empty());
}
}
if(rdata->sensi >= SensitivityOrder::first) {
verifyReturnDataSensitivities(file, resultPath, rdata, model, atol, rtol);
} else {
Expand Down

0 comments on commit 8d65524

Please sign in to comment.