From dfb8586d3069beaaef0bcc3d1769fa8316d675f2 Mon Sep 17 00:00:00 2001 From: Mathieu Blondel Date: Wed, 18 Jan 2017 11:32:32 +0900 Subject: [PATCH] Fix project_simplex. When the constraint is already satisfied, there is nothing to do. --- lightning/impl/penalty.py | 3 ++ lightning/impl/tests/test_penalty.py | 42 ++++++++++++++++++++++++++-- 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/lightning/impl/penalty.py b/lightning/impl/penalty.py index 6708e399..56614fdf 100644 --- a/lightning/impl/penalty.py +++ b/lightning/impl/penalty.py @@ -53,6 +53,9 @@ def regularization(self, coef): # See https://gist.github.com/mblondel/6f3b7aaad90606b98f71 # for more algorithms. def project_simplex(v, z=1): + if np.sum(v) <= z: + return v + n_features = v.shape[0] u = np.sort(v)[::-1] cssv = np.cumsum(u) - z diff --git a/lightning/impl/tests/test_penalty.py b/lightning/impl/tests/test_penalty.py index 418dc5fe..5fce5c8e 100644 --- a/lightning/impl/tests/test_penalty.py +++ b/lightning/impl/tests/test_penalty.py @@ -1,7 +1,45 @@ import numpy as np -from sklearn.utils.testing import assert_almost_equal +from sklearn.utils.testing import assert_almost_equal, assert_array_almost_equal -from lightning.impl.penalty import project_l1_ball +from lightning.impl.penalty import project_l1_ball, project_simplex + + +def project_simplex_bisection(v, z=1, tau=0.0001, max_iter=1000): + lower = 0 + upper = np.max(v) + current = np.inf + + for it in xrange(max_iter): + if np.abs(current) / z < tau and current < 0: + break + + theta = (upper + lower) / 2.0 + w = np.maximum(v - theta, 0) + current = np.sum(w) - z + if current <= 0: + upper = theta + else: + lower = theta + return w + + +def test_proj_simplex(): + rng = np.random.RandomState(0) + + v = rng.rand(100) + w = project_simplex(v, z=10) + w2 = project_simplex_bisection(v, z=10, max_iter=100) + assert_array_almost_equal(w, w2, 3) + + v = rng.rand(3) + w = project_simplex(v, z=1) + w2 = project_simplex_bisection(v, z=1, max_iter=100) + assert_array_almost_equal(w, w2, 3) + + v = rng.rand(2) + w = project_simplex(v, z=1) + w2 = project_simplex_bisection(v, z=1, max_iter=100) + assert_array_almost_equal(w, w2, 3) def test_proj_l1_ball():