diff --git a/optax/losses/_classification.py b/optax/losses/_classification.py
index 5db806c7..68bda214 100644
--- a/optax/losses/_classification.py
+++ b/optax/losses/_classification.py
@@ -103,6 +103,37 @@ def perceptron_loss(
   return jnp.maximum(0, - predictor_outputs * targets)
 
 
+def sparsemax_loss(
+    logits: chex.Array,
+    labels: chex.Array,
+) -> chex.Array:
+  """Binary sparsemax loss.
+
+  This loss is zero if and only if `jax.nn.sparse_sigmoid(logits) == labels`.
+
+  References:
+    Learning with Fenchel-Young Losses. Mathieu Blondel, André F. T. Martins,
+    Vlad Niculae. JMLR 2020. (Sec. 4.4)
+
+  Args:
+    logits: score produced by the model (float).
+    labels: ground-truth integer label (0 or 1).
+
+  Returns:
+    loss value
+
+  .. versionadded:: 0.2.3
+  """
+  return jax.nn.sparse_plus(jnp.where(labels, -logits, logits))
+
+
+@functools.partial(
+    chex.warn_deprecated_function,
+    replacement='sparsemax_loss')
+def binary_sparsemax_loss(logits, labels):
+  return sparsemax_loss(logits, labels)
+
+
 def softmax_cross_entropy(
     logits: chex.Array,
     labels: chex.Array,
@@ -183,6 +214,9 @@ def multiclass_hinge_loss(
 ) -> chex.Array:
   """Multiclass hinge loss.
 
+  References:
+    https://en.wikipedia.org/wiki/Hinge_loss
+
   Args:
     scores: scores produced by the model (floats).
     labels: ground-truth integer label.
@@ -190,9 +224,6 @@ def multiclass_hinge_loss(
   Returns:
     loss value
 
-  References:
-    https://en.wikipedia.org/wiki/Hinge_loss
-
   .. versionadded:: 0.2.3
   """
   one_hot_labels = jax.nn.one_hot(labels, scores.shape[-1])
diff --git a/optax/losses/_classification_test.py b/optax/losses/_classification_test.py
index 86ca25a0..1ad7d9db 100644
--- a/optax/losses/_classification_test.py
+++ b/optax/losses/_classification_test.py
@@ -276,6 +276,43 @@ def reference_impl(label, scores):
     np.testing.assert_allclose(result, expected, atol=1e-4)
 
 
+class SparsemaxTest(parameterized.TestCase):
+
+  def test_binary(self):
+    label = 1
+    score = 10.
+    def reference_impl(label, logit):
+      scores = -(2*label-1)*logit
+      if scores <= -1.0:
+        return 0.0
+      elif scores >= 1.0:
+        return scores
+      else:
+        return (scores + 1.0) ** 2 / 4
+    expected = reference_impl(label, score)
+    result = _classification.sparsemax_loss(
+        jnp.asarray(score), jnp.asarray(label))
+    np.testing.assert_allclose(result, expected, atol=1e-4)
+
+  def test_batched_binary(self):
+    labels = jnp.array([1, 0])
+    scores = jnp.array([10., 20.])
+    def reference_impl(label, logit):
+      scores = -(2*label-1)*logit
+      if scores <= -1.0:
+        return 0.0
+      elif scores >= 1.0:
+        return scores
+      else:
+        return (scores + 1.0) ** 2 / 4
+    expected = jnp.asarray([
+        reference_impl(labels[0], scores[0]),
+        reference_impl(labels[1], scores[1])])
+    # in the optax loss the leading dimensions are automatically handled.
+    result = _classification.sparsemax_loss(scores, labels)
+    np.testing.assert_allclose(result, expected, atol=1e-4)
+
+
 class ConvexKLDivergenceTest(parameterized.TestCase):
 
   def setUp(self):
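
Usage sketch (not part of the patch): a minimal example of exercising the new sparsemax_loss once this diff is applied. It reuses jax.nn.sparse_plus, which the patch itself calls, and mirrors the closed form spelled out by reference_impl in the tests; the helper name reference_binary_sparsemax_loss below is illustrative only, not part of the optax API.

import jax
import jax.numpy as jnp

from optax.losses import _classification  # assumes this patch is applied


def reference_binary_sparsemax_loss(logit, label):
  # Closed form of the binary sparsemax loss for a single example, following
  # reference_impl in the tests above: zero when confidently correct, linear
  # when confidently wrong, quadratic in between.
  t = -(2 * label - 1) * logit
  if t <= -1.0:
    return 0.0
  if t >= 1.0:
    return t
  return (t + 1.0) ** 2 / 4


logits = jnp.array([10.0, 20.0, 0.3])
labels = jnp.array([1, 0, 1])

# Batched call; leading dimensions are handled elementwise.
losses = _classification.sparsemax_loss(logits, labels)

# Equivalent direct formulation used inside the new function.
same = jax.nn.sparse_plus(jnp.where(labels, -logits, logits))
# Both evaluate to approximately [0., 20., 0.1225].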