Skip to content

Commit

Permalink
Allow setting MaxConvFails and MaxNonlinIters on solvers (#2335)
Browse files Browse the repository at this point in the history
  • Loading branch information
dweindl authored Mar 4, 2024
1 parent 4782cad commit 239ce2c
Show file tree
Hide file tree
Showing 9 changed files with 132 additions and 4 deletions.
2 changes: 2 additions & 0 deletions include/amici/serialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ void serialize(Archive& ar, amici::Solver& s, unsigned int const /*version*/) {
ar & s.check_sensi_steadystate_conv_;
ar & s.rdata_mode_;
ar & s.maxtime_;
ar & s.max_conv_fails_;
ar & s.max_nonlin_iters_;
}

/**
Expand Down
47 changes: 47 additions & 0 deletions include/amici/solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -936,6 +936,34 @@ class Solver {
check_sensi_steadystate_conv_ = flag;
}

/**
* @brief Set the maximum number of nonlinear solver iterations permitted
* per step.
* @param max_nonlin_iters maximum number of nonlinear solver iterations
*/
void setMaxNonlinIters(int max_nonlin_iters);

/**
* @brief Get the maximum number of nonlinear solver iterations permitted
* per step.
* @return maximum number of nonlinear solver iterations
*/
int getMaxNonlinIters() const;

/**
* @brief Set the maximum number of nonlinear solver convergence failures
* permitted per step.
* @param max_conv_fails maximum number of nonlinear solver convergence
*/
void setMaxConvFails(int max_conv_fails);

/**
* @brief Get the maximum number of nonlinear solver convergence failures
* permitted per step.
* @return maximum number of nonlinear solver convergence
*/
int getMaxConvFails() const;

/**
* @brief Serialize Solver (see boost::serialization::serialize)
* @param ar Archive to serialize to
Expand Down Expand Up @@ -1712,6 +1740,18 @@ class Solver {
SensitivityMethod const sensi_meth, bool preequilibration
) const;

/**
* @brief Apply the maximum number of nonlinear solver iterations permitted
* per step.
*/
virtual void apply_max_nonlin_iters() const = 0;

/**
* @brief Apply the maximum number of nonlinear solver convergence failures
* permitted per step.
*/
virtual void apply_max_conv_fails() const = 0;

/** state (dimension: nx_solver) */
mutable AmiVector x_{0};

Expand Down Expand Up @@ -1844,6 +1884,13 @@ class Solver {
*/
bool check_sensi_steadystate_conv_{true};

/** Maximum number of nonlinear solver iterations permitted per step */
int max_nonlin_iters_{3};

/** Maximum number of nonlinear solver convergence failures permitted per
* step */
int max_conv_fails_{10};

/** CPU time, forward solve */
mutable realtype cpu_time_{0.0};

Expand Down
4 changes: 4 additions & 0 deletions include/amici/solver_cvodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,10 @@ class CVodeSolver : public Solver {
void setJacTimesVecFnB(int which) const override;

void setSparseJacFn_ss() const override;

void apply_max_nonlin_iters() const override;

void apply_max_conv_fails() const override;
};

} // namespace amici
Expand Down
4 changes: 4 additions & 0 deletions include/amici/solver_idas.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,10 @@ class IDASolver : public Solver {
void setJacTimesVecFnB(int which) const override;

void setSparseJacFn_ss() const override;

void apply_max_nonlin_iters() const override;

void apply_max_conv_fails() const override;
};

} // namespace amici
Expand Down
22 changes: 22 additions & 0 deletions src/hdf5.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,16 @@ void writeSolverSettingsToHDF5(
file.getId(), hdf5Location.c_str(), "check_sensi_steadystate_conv",
&ibuffer, 1
);

ibuffer = static_cast<int>(solver.getMaxNonlinIters());
H5LTset_attribute_int(
file.getId(), hdf5Location.c_str(), "max_nonlin_iters", &ibuffer, 1
);

ibuffer = static_cast<int>(solver.getMaxConvFails());
H5LTset_attribute_int(
file.getId(), hdf5Location.c_str(), "max_conv_fails", &ibuffer, 1
);
}

void readSolverSettingsFromHDF5(
Expand Down Expand Up @@ -1106,6 +1116,18 @@ void readSolverSettingsFromHDF5(
file, datasetPath, "check_sensi_steadystate_conv"
));
}

if (attributeExists(file, datasetPath, "max_nonlin_iters")) {
solver.setMaxNonlinIters(
getIntScalarAttribute(file, datasetPath, "max_nonlin_iters")
);
}

if (attributeExists(file, datasetPath, "max_conv_fails")) {
solver.setMaxConvFails(
getIntScalarAttribute(file, datasetPath, "max_conv_fails")
);
}
}

void readSolverSettingsFromHDF5(
Expand Down
4 changes: 1 addition & 3 deletions src/newton_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,7 @@ void NewtonSolverSparse::prepareLinearSystemB(
Model& model, SimulationState const& state
) {
/* Get sparse Jacobian */
model.fJSparseB(
state.t, 0.0, state.x, state.dx, xB_, dxB_, xdot_, Jtmp_
);
model.fJSparseB(state.t, 0.0, state.x, state.dx, xB_, dxB_, xdot_, Jtmp_);
Jtmp_.refresh();
auto status = SUNLinSolSetup_KLU(linsol_, Jtmp_);
if (status != SUNLS_SUCCESS)
Expand Down
27 changes: 26 additions & 1 deletion src/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ Solver::Solver(Solver const& other)
, rdata_mode_(other.rdata_mode_)
, newton_step_steadystate_conv_(other.newton_step_steadystate_conv_)
, check_sensi_steadystate_conv_(other.check_sensi_steadystate_conv_)
, max_nonlin_iters_(other.max_nonlin_iters_)
, max_conv_fails_(other.max_conv_fails_)
, maxstepsB_(other.maxstepsB_)
, sensi_(other.sensi_) {}

Expand Down Expand Up @@ -191,6 +193,9 @@ void Solver::setup(
if (model->nt() > 1)
calcIC(model->getTimepoint(1));

apply_max_nonlin_iters();
apply_max_conv_fails();

cpu_time_ = 0.0;
cpu_timeB_ = 0.0;
}
Expand Down Expand Up @@ -553,7 +558,9 @@ bool operator==(Solver const& a, Solver const& b) {
== b.newton_step_steadystate_conv_)
&& (a.check_sensi_steadystate_conv_
== b.check_sensi_steadystate_conv_)
&& (a.rdata_mode_ == b.rdata_mode_);
&& (a.rdata_mode_ == b.rdata_mode_)
&& (a.max_conv_fails_ == b.max_conv_fails_)
&& (a.max_nonlin_iters_ == b.max_nonlin_iters_);
}

void Solver::applyTolerances() const {
Expand Down Expand Up @@ -666,6 +673,24 @@ void Solver::checkSensitivityMethod(
resetMutableMemory(nx(), nplist(), nquad());
}

void Solver::setMaxNonlinIters(int max_nonlin_iters) {
if (max_nonlin_iters < 0)
throw AmiException("max_nonlin_iters must be a non-negative number");

max_nonlin_iters_ = max_nonlin_iters;
}

int Solver::getMaxNonlinIters() const { return max_nonlin_iters_; }

void Solver::setMaxConvFails(int max_conv_fails) {
if (max_conv_fails < 0)
throw AmiException("max_conv_fails must be a non-negative number");

max_conv_fails_ = max_conv_fails;
}

int Solver::getMaxConvFails() const { return max_conv_fails_; }

int Solver::getNewtonMaxSteps() const { return newton_maxsteps_; }

void Solver::setNewtonMaxSteps(int const newton_maxsteps) {
Expand Down
13 changes: 13 additions & 0 deletions src/solver_cvodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,19 @@ void CVodeSolver::setSparseJacFn_ss() const {
throw CvodeException(status, "CVodeSetJacFn");
}

void CVodeSolver::apply_max_nonlin_iters() const {
int status
= CVodeSetMaxNonlinIters(solver_memory_.get(), getMaxNonlinIters());
if (status != CV_SUCCESS)
throw CvodeException(status, "CVodeSetMaxNonlinIters");
}

void CVodeSolver::apply_max_conv_fails() const {
int status = CVodeSetMaxConvFails(solver_memory_.get(), getMaxConvFails());
if (status != CV_SUCCESS)
throw CvodeException(status, "CVodeSetMaxConvFails");
}

Solver* CVodeSolver::clone() const { return new CVodeSolver(*this); }

void CVodeSolver::allocateSolver() const {
Expand Down
13 changes: 13 additions & 0 deletions src/solver_idas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,19 @@ void IDASolver::setSparseJacFn_ss() const {
throw IDAException(status, "IDASetJacFn");
}

void IDASolver::apply_max_nonlin_iters() const {
int status
= IDASetMaxNonlinIters(solver_memory_.get(), getMaxNonlinIters());
if (status != IDA_SUCCESS)
throw IDAException(status, "IDASetMaxNonlinIters");
}

void IDASolver::apply_max_conv_fails() const {
int status = IDASetMaxConvFails(solver_memory_.get(), getMaxConvFails());
if (status != IDA_SUCCESS)
throw IDAException(status, "IDASetMaxConvFails");
}

Solver* IDASolver::clone() const { return new IDASolver(*this); }

void IDASolver::allocateSolver() const {
Expand Down

0 comments on commit 239ce2c

Please sign in to comment.