Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 42 additions & 28 deletions src/optimagic/optimization/internal_optimization_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import warnings
from copy import copy
from dataclasses import asdict, dataclass, replace
from typing import Any, Callable, cast
from typing import Any, Callable, Literal, cast

import numpy as np
from numpy.typing import NDArray
Expand All @@ -11,10 +11,7 @@
from optimagic.batch_evaluators import process_batch_evaluator
from optimagic.differentiation.derivatives import first_derivative
from optimagic.differentiation.numdiff_options import NumdiffOptions
from optimagic.exceptions import (
UserFunctionRuntimeError,
get_traceback,
)
from optimagic.exceptions import UserFunctionRuntimeError, get_traceback
from optimagic.logging.logger import LogStore
from optimagic.logging.types import IterationState
from optimagic.optimization.fun_value import (
Expand Down Expand Up @@ -474,7 +471,7 @@ def _pure_evaluate_jac(
out_jac = _process_jac_value(
value=jac_value, direction=self._direction, converter=self._converter, x=x
)
self._assert_finite_jac(out_jac, jac_value, params)
_assert_finite_jac(out_jac, jac_value, params, "jac")

stop_time = time.perf_counter()

Expand Down Expand Up @@ -548,7 +545,7 @@ def func(x: NDArray[np.float64]) -> SpecificFunctionValue:
warnings.warn(msg)
fun_value, jac_value = self._error_penalty_func(x)

self._assert_finite_jac(jac_value, jac_value, params)
_assert_finite_jac(jac_value, jac_value, params, "numerical")

algo_fun_value, hist_fun_value = _process_fun_value(
value=fun_value, # type: ignore
Expand Down Expand Up @@ -689,7 +686,7 @@ def _pure_evaluate_fun_and_jac(
if self._direction == Direction.MAXIMIZE:
out_jac = -out_jac

self._assert_finite_jac(out_jac, jac_value, params)
_assert_finite_jac(out_jac, jac_value, params, "fun_and_jac")

stop_time = time.perf_counter()

Expand All @@ -713,31 +710,48 @@ def _pure_evaluate_fun_and_jac(

return (algo_fun_value, out_jac), hist_entry, log_entry

def _assert_finite_jac(
self, out_jac: NDArray[np.float64], jac_value: PyTree, params: PyTree
) -> None:
"""Check for infinite and NaN values in the jacobian and raise an error if
found.

Args:
out_jac: internal processed jacobian to check for infinities.
jac_value: original jacobian value as returned by the user function,
included in error messages for debugging.
params: user-facing parameter representation at evaluation point.
def _assert_finite_jac(
out_jac: NDArray[np.float64],
jac_value: PyTree,
params: PyTree,
origin: Literal["numerical", "jac", "fun_and_jac"],
) -> None:
"""Check for infinite and NaN values in the Jacobian and raise an error if found.

Raises:
UserFunctionRuntimeError: If any infinite values are found in the jacobian.
Args:
out_jac: internal processed Jacobian to check for finiteness.
jac_value: original Jacobian value as returned by the user function,
params: user-facing parameter representation at evaluation point.

"""
if not np.all(np.isfinite(out_jac)):
Raises:
UserFunctionRuntimeError:
If any infinite or NaN values are found in the Jacobian.

"""
if not np.all(np.isfinite(out_jac)):
if origin == "jac":
msg = (
"The optimization failed because the derivative provided via "
"jac contains infinite or NaN values."
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since most of the messages is the same in all three cases, you could use an f-string instead of the if-elif blocks. I.e. something like f"The optimization failed because the derivative provided via {origin} ..."

"\nPlease validate the derivative function."
)
elif origin == "fun_and_jac":
msg = (
"The optimization received Jacobian containing infinite "
"or NaN values.\nCheck your objective function or its "
"jacobian, or try a different optimizer.\n"
f"Parameters at evaluation point: {params}\n"
f"Jacobian values: {jac_value}"
"The optimization failed because the derivative provided via "
"fun_and_jac contains infinite or NaN values."
"\nPlease validate the derivative function."
)
raise UserFunctionRuntimeError(msg)
elif origin == "numerical":
msg = (
"The optimization failed because the numerical derivative "
"(computed using fun) contains infinite or NaN values."
"\nPlease validate the criterion function or try a different optimizer."
)
msg += (
f"\nParameters at evaluation point: {params}\nJacobian values: {jac_value}"
)
raise UserFunctionRuntimeError(msg)


def _process_fun_value(
Expand Down
Loading
Loading