Skip to content

Commit 0e13fb3

Browse files
sdaultonfacebook-github-bot
authored andcommitted
fix EBTS bug where gen returned less than n total weight (facebook#3316)
Summary: This could happen if enought `weights` were zero here than the length of `weights` became less than `n`: https://www.internalfb.com/code/fbsource/[b9c286a9f21709eb4c964c3c24bb629b2b218c86]/fbcode/ax/models/discrete/thompson.py?lines=109-113. The `AlmostEqual` is needed to deal with numerical precision (e.g. 9.0000000002) Reviewed By: danielcohenlive Differential Revision: D69253157
1 parent 01c6a73 commit 0e13fb3

File tree

3 files changed

+4
-7
lines changed

3 files changed

+4
-7
lines changed

Diff for: ax/models/discrete/thompson.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@ def gen(
104104
objective_weights=objective_weights, outcome_constraints=outcome_constraints
105105
)
106106
min_weight = self.min_weight if self.min_weight is not None else 2.0 / k
107-
108107
# Second entry is used for tie-breaking
109108
weighted_arms = [
110109
(weights[i], np.random.random(), arms[i])
@@ -128,9 +127,7 @@ def gen(
128127
if self.uniform_weights:
129128
top_weights = [1.0 for _ in top_weights]
130129
else:
131-
top_weights = [
132-
(x * len(top_weights)) / sum(top_weights) for x in top_weights
133-
]
130+
top_weights = [(x * n) / sum(top_weights) for x in top_weights]
134131
return (
135132
top_arms,
136133
top_weights,

Diff for: ax/models/tests/test_eb_thompson.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def test_EmpiricalBayesThompsonSamplerGen(self) -> None:
7575
)
7676
self.assertEqual(arms, [[4, 4], [3, 3], [2, 2], [1, 1]])
7777
for weight, expected_weight in zip(
78-
weights, [4 * i for i in [0.66, 0.25, 0.07, 0.02]]
78+
weights, [5 * i for i in [0.66, 0.25, 0.07, 0.02]]
7979
):
8080
self.assertAlmostEqual(weight, expected_weight, delta=0.1)
8181

@@ -95,7 +95,7 @@ def test_EmpiricalBayesThompsonSamplerWarning(self) -> None:
9595
)
9696
self.assertEqual(arms, [[3, 3], [2, 2], [1, 1]])
9797
for weight, expected_weight in zip(
98-
weights, [3 * i for i in [0.74, 0.21, 0.05]]
98+
weights, [5 * i for i in [0.74, 0.21, 0.05]]
9999
):
100100
self.assertAlmostEqual(weight, expected_weight, delta=0.1)
101101

Diff for: ax/models/tests/test_thompson.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def test_ThompsonSamplerMinWeight(self) -> None:
105105
)
106106
self.assertEqual(arms, [[4, 4], [3, 3], [2, 2]])
107107
for weight, expected_weight in zip(
108-
weights, [3 * i for i in [0.725, 0.225, 0.05]]
108+
weights, [5 * i for i in [0.725, 0.225, 0.05]]
109109
):
110110
self.assertAlmostEqual(weight, expected_weight, 1)
111111

0 commit comments

Comments
 (0)