Upstream missing jaxopt losses to optax - Part 2/N
Adding deprecated aliases for the names that `softmax_cross_entropy_with_integer_labels` and `sigmoid_binary_cross_entropy` have in `jaxopt`, to make it easier for users moving over to find the corresponding losses in optax.

PiperOrigin-RevId: 617196448
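
For users migrating from jaxopt, a minimal usage sketch of what these aliases do once released; it assumes the names are re-exported from `optax.losses`, which this diff (touching only the private `_classification.py` module) does not show:

import jax.numpy as jnp
import optax

logits = jnp.array([[1.2, -0.3, 0.7]])
labels = jnp.array([0])

# New optax name.
new = optax.losses.softmax_cross_entropy_with_integer_labels(logits, labels)

# Old jaxopt name: now a thin alias that emits a deprecation warning via
# chex.warn_deprecated_function and then delegates to the function above.
old = optax.losses.multiclass_logistic_loss(logits, labels)

assert jnp.allclose(new, old)  # same values, plus the deprecation warning

`binary_logistic_loss` behaves the same way, delegating to `sigmoid_binary_cross_entropy`.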
mtthss authored and OptaxDev committed Mar 19, 2024
1 parent d6e2c30 commit f421824
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions optax/losses/_classification.py
@@ -60,6 +60,13 @@ class is an independent binary prediction and different classes are not
  return -labels * log_p - (1. - labels) * log_not_p


@functools.partial(
    chex.warn_deprecated_function,
    replacement='sigmoid_binary_cross_entropy')
def binary_logistic_loss(logits, labels):
  return sigmoid_binary_cross_entropy(logits, labels)


def hinge_loss(
    predictor_outputs: chex.Array,
    targets: chex.Array
@@ -76,6 +83,25 @@ def hinge_loss(
  return jnp.maximum(0, 1 - predictor_outputs * targets)


def perceptron_loss(
    predictor_outputs: chex.Numeric,
    targets: chex.Numeric
) -> chex.Numeric:
  """Binary perceptron loss.

  References:
    https://en.wikipedia.org/wiki/Perceptron

  Args:
    predictor_outputs: score produced by the model (float).
    targets: Target values; should be strictly in the set {-1, 1}.

  Returns:
    loss value.
  """
  return jnp.maximum(0, - predictor_outputs * targets)


def softmax_cross_entropy(
    logits: chex.Array,
    labels: chex.Array,
@@ -139,6 +165,13 @@ def softmax_cross_entropy_with_integer_labels(
  return log_normalizers - label_logits


@functools.partial(
    chex.warn_deprecated_function,
    replacement='softmax_cross_entropy_with_integer_labels')
def multiclass_logistic_loss(logits, labels):
  return softmax_cross_entropy_with_integer_labels(logits, labels)


@functools.partial(chex.warn_only_n_pos_args_in_future, n=2)
def poly_loss_cross_entropy(
    logits: chex.Array,
