Upstream sparsemax jaxopt loss to optax.
PiperOrigin-RevId: 621171127
mtthss authored and OptaxDev committed Apr 8, 2024
1 parent b0c04dc commit 4d621ff
Showing 2 changed files with 71 additions and 3 deletions.
37 changes: 34 additions & 3 deletions optax/losses/_classification.py
@@ -103,6 +103,37 @@ def perceptron_loss(
return jnp.maximum(0, - predictor_outputs * targets)


def sparsemax_loss(
    logits: chex.Array,
    labels: chex.Array,
) -> chex.Array:
  """Binary sparsemax loss.

  This loss is zero if and only if `jax.nn.sparse_sigmoid(logits) == labels`.

  References:
    Learning with Fenchel-Young Losses. Mathieu Blondel, André F. T. Martins,
    Vlad Niculae. JMLR 2020. (Sec. 4.4)

  Args:
    logits: score produced by the model (float).
    labels: ground-truth integer label (0 or 1).

  Returns:
    loss value

  .. versionadded:: 0.2.3
  """
  return jax.nn.sparse_plus(jnp.where(labels, -logits, logits))


@functools.partial(
    chex.warn_deprecated_function,
    replacement='sparsemax_loss')
def binary_sparsemax_loss(logits, labels):
  return sparsemax_loss(logits, labels)


def softmax_cross_entropy(
logits: chex.Array,
labels: chex.Array,
@@ -183,16 +214,16 @@ def multiclass_hinge_loss(
) -> chex.Array:
  """Multiclass hinge loss.

  References:
    https://en.wikipedia.org/wiki/Hinge_loss

  Args:
    scores: scores produced by the model (floats).
    labels: ground-truth integer label.

  Returns:
    loss value

  References:
    https://en.wikipedia.org/wiki/Hinge_loss

  .. versionadded:: 0.2.3
  """
  one_hot_labels = jax.nn.one_hot(labels, scores.shape[-1])
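For orientation, a minimal usage sketch of the new loss with illustrative values. It imports the function from optax/losses/_classification.py, the module touched by this commit (whether the name is also re-exported from a public optax namespace is not shown in this diff):

import jax.numpy as jnp
from optax.losses import _classification

# Binary labels (0 or 1) and raw model scores; the loss is computed elementwise,
# so leading batch dimensions are handled automatically.
labels = jnp.array([1, 0, 1])
logits = jnp.array([2.5, -0.3, 0.1])

losses = _classification.sparsemax_loss(logits, labels)  # shape (3,); zero where the margin is large enough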
37 changes: 37 additions & 0 deletions optax/losses/_classification_test.py
@@ -276,6 +276,43 @@ def reference_impl(label, scores):
np.testing.assert_allclose(result, expected, atol=1e-4)


class SparsemaxTest(parameterized.TestCase):

  def test_binary(self):
    label = 1
    score = 10.
    def reference_impl(label, logit):
      scores = -(2*label-1)*logit
      if scores <= -1.0:
        return 0.0
      elif scores >= 1.0:
        return scores
      else:
        return (scores + 1.0) ** 2 / 4
    expected = reference_impl(label, score)
    result = _classification.sparsemax_loss(
        jnp.asarray(score), jnp.asarray(label))
    np.testing.assert_allclose(result, expected, atol=1e-4)

  def test_batched_binary(self):
    labels = jnp.array([1, 0])
    scores = jnp.array([10., 20.])
    def reference_impl(label, logit):
      scores = -(2*label-1)*logit
      if scores <= -1.0:
        return 0.0
      elif scores >= 1.0:
        return scores
      else:
        return (scores + 1.0) ** 2 / 4
    expected = jnp.asarray([
        reference_impl(labels[0], scores[0]),
        reference_impl(labels[1], scores[1])])
    # in the optax loss the leading dimensions are automatically handled.
    result = _classification.sparsemax_loss(scores, labels)
    np.testing.assert_allclose(result, expected, atol=1e-4)


class ConvexKLDivergenceTest(parameterized.TestCase):

def setUp(self):
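The reference implementation in these tests is the closed form of jax.nn.sparse_plus evaluated at the signed score x = -(2*label - 1)*logit, i.e. -logit when label is 1 and logit when label is 0: 0 for x <= -1, (x + 1)**2 / 4 for -1 < x < 1, and x for x >= 1. This matches jnp.where(labels, -logits, logits) fed into jax.nn.sparse_plus in the implementation above. A standalone cross-check in the same spirit (a sketch with made-up inputs, not part of the commit):

import jax.numpy as jnp
from optax.losses import _classification

def sparse_plus_reference(x):
  # Piecewise closed form of jax.nn.sparse_plus, mirroring reference_impl above.
  return jnp.where(x <= -1.0, 0.0, jnp.where(x >= 1.0, x, (x + 1.0) ** 2 / 4))

labels = jnp.array([1, 0])
logits = jnp.array([0.5, -0.2])
expected = sparse_plus_reference(-(2 * labels - 1) * logits)
actual = _classification.sparsemax_loss(logits, labels)
assert jnp.allclose(actual, expected, atol=1e-6)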
