Skip to content

Commit ae14e64

Browse files
committed
FIX: fixed fidelity#97 + better code organization
1 parent 2fd6986 commit ae14e64

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

mabwiser/linear.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -230,19 +230,20 @@ def _vectorized_predict_context(self, contexts: np.ndarray, is_predict: bool) ->
230230

231231
# With epsilon probability, assign random flag to context
232232
random_values = self.rng.rand(num_contexts)
233-
p_arms = np.ones(num_contexts) * self.epsilon / len(arms)
234233
random_mask = np.array(random_values < self.epsilon)
235-
p_arms[~random_mask] += 1.0 - self.epsilon
236234
random_indices = random_mask.nonzero()[0]
237235

238236
# For random indices, generate random expectations
239237
arm_expectations[random_indices] = self.rng.rand((random_indices.shape[0], len(arms)))
238+
p_arms = np.ones(num_contexts) * self.epsilon / len(arms)
240239

241240
# For non-random indices, get expectations for each arm
242241
nonrandom_indices = np.where(~random_mask)[0]
243242
nonrandom_context = contexts[nonrandom_indices]
244-
arm_expectations[nonrandom_indices] = np.array([self.arm_to_model[arm].predict(nonrandom_context)[0]
245-
for arm in arms]).T
243+
if len(nonrandom_context) > 0:
244+
arm_expectations[nonrandom_indices] = np.array([self.arm_to_model[arm].predict(nonrandom_context)
245+
for arm in arms]).T
246+
p_arms[nonrandom_indices] += 1.0 - self.epsilon
246247

247248
if is_predict:
248249
predictions = arms[argmax_2D(arm_expectations)].tolist()

0 commit comments

Comments
 (0)