Skip to content

Commit

Permalink
Merge pull request #409 from heytanay:add_hinge_loss
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 486617908
  • Loading branch information
OptaxDev committed Nov 7, 2022
2 parents 47fe655 + 84c5bba commit 509c706
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,7 @@ Common Losses
cosine_similarity
ctc_loss
ctc_loss_with_forward_probs
hinge_loss
huber_loss
l2_loss
log_cosh
Expand All @@ -506,6 +507,7 @@ Losses
.. autofunction:: cosine_similarity
.. autofunction:: ctc_loss
.. autofunction:: ctc_loss_with_forward_probs
.. autofunction:: hinge_loss
.. autofunction:: huber_loss
.. autofunction:: l2_loss
.. autofunction:: log_cosh
Expand Down
2 changes: 2 additions & 0 deletions optax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
from optax._src.loss import cosine_similarity
from optax._src.loss import ctc_loss
from optax._src.loss import ctc_loss_with_forward_probs
from optax._src.loss import hinge_loss
from optax._src.loss import huber_loss
from optax._src.loss import l2_loss
from optax._src.loss import log_cosh
Expand Down Expand Up @@ -230,6 +231,7 @@
"fromage",
"global_norm",
"GradientTransformation",
"hinge_loss",
"hessian_diag",
"huber_loss",
"hvp",
Expand Down
14 changes: 14 additions & 0 deletions optax/_src/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,3 +505,17 @@ def kl_divergence_with_log_targets(log_predictions: chex.Array,
chex.assert_type([log_predictions, log_targets], float)
loss = jnp.exp(log_targets) * (log_targets - log_predictions)
return jnp.sum(loss, axis=-1)


def hinge_loss(predictor_outputs: chex.Array,
targets: chex.Array) -> chex.Array:
"""Computes the hinge loss for binary classification.
Args:
predictor_outputs: Outputs of the decision function.
targets: Target values. Target values should be strictly in the set {-1, 1}.
Returns:
Binary Hinge Loss.
"""
return jnp.maximum(0, 1 - predictor_outputs * targets)
21 changes: 21 additions & 0 deletions optax/_src/loss_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,5 +475,26 @@ def test_batched(self):
self.exp,
atol=1e-4)


class HingeLossTest(parameterized.TestCase):

def setUp(self):
super().setUp()
self.ys = np.array([
-0.97740268, -1.01812625, -0.81675726, -0.73605974, 2.08235648,
1.84101354, -1.0581002
])
self.ts = np.array([-1, -1, -1, -1, 1, 1, -1])
# Computed expected outputs.
self.correct_result = np.array(
[0.02259731, 0., 0.18324274, 0.26394027, 0., 0., 0.])

@chex.all_variants
def test_batched(self):
np.testing.assert_allclose(
self.variant(loss.hinge_loss)(self.ys, self.ts),
self.correct_result,
atol=1e-4)

if __name__ == '__main__':
absltest.main()

0 comments on commit 509c706

Please sign in to comment.