Skip to content

Commit 646a398

Browse files
authored
Bugfixes in examples and tree params (#57)
1 parent 329125d commit 646a398

File tree

7 files changed

+18
-23
lines changed

7 files changed

+18
-23
lines changed

CHANGELOG.txt

+6
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,12 @@
22
MABWiser CHANGELOG
33
=====================
44

5+
March, 28, 2022 2.4.1
6+
-------------------------------------------------------------------------------
7+
minor:
8+
- Bug fixes in examples
9+
- Validate tree parameters of TreeBandit to be compatible with sklearn.tree.DecisionTreeRegressor
10+
511
March, 17, 2022 2.4.0
612
-------------------------------------------------------------------------------
713
major:

examples/context_free_mab.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -83,7 +83,7 @@
8383

8484
# Results
8585
print("Randomized Popularity: ", prediction, " ", expectations)
86-
assert(prediction == 2)
86+
assert(prediction == 1)
8787

8888
###################################
8989
# Softmax Learning Policy

examples/customized_mab.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -120,7 +120,7 @@ def predict(self, contexts: np.ndarray=None):
120120
class LinUCBColdStart(_Linear):
121121
def __init__(self, rng, arms, n_jobs, backend, l2_lambda=1.0, alpha=1.0, features=None):
122122
# initialize the parent class as is
123-
super().__init__(rng, arms, n_jobs, backend, l2_lambda, alpha, 'ucb')
123+
super().__init__(rng, arms, n_jobs, backend, alpha, 0.0, l2_lambda, 'ucb', False)
124124

125125
# save the feature vectors
126126
self.features = features

examples/parallel_mab.py

+1-12
Original file line number | Diff line number | Diff line change
@@ -39,24 +39,13 @@
3939
rewards_train, rewards_test = train_test_split(rewards, test_size=0.3, random_state=seed)
4040
decisions_train, decisions_test = train_test_split(decisions, test_size=0.3, random_state=seed)
4141

42-
# Fit standard scaler for each arm
43-
arm_to_scaler = {}
44-
for arm in arms:
45-
# Get indices for arm
46-
indices = np.where(decisions_train == arm)
47-
48-
# Fit standard scaler
49-
scaler = StandardScaler()
50-
scaler.fit(contexts[indices])
51-
arm_to_scaler[arm] = scaler
52-
5342
########################################################
5443
# LinUCB Learning Policy
5544
########################################################
5645

5746
# LinUCB learning policy with alpha 1.25 and n_jobs = -1 (maximum available cores)
5847
linucb = MAB(arms=arms,
59-
learning_policy=LearningPolicy.LinUCB(alpha=1.25, arm_to_scaler=arm_to_scaler),
48+
learning_policy=LearningPolicy.LinUCB(alpha=1.25, scale=True),
6049
n_jobs=-1)
6150

6251
# Learn from playlists shown and observed click rewards for each arm

mabwiser/_version.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -3,5 +3,5 @@
33

44
__author__ = "FMR LLC"
55
__email__ = "[email protected]"
6-
__version__ = "2.4.0"
6+
__version__ = "2.4.1"
77
__copyright__ = "Copyright (C), FMR LLC"

mabwiser/mab.py

+5-5
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515
import numpy as np
1616
import pandas as pd
1717
from sklearn.cluster import MiniBatchKMeans
18-
from sklearn.tree import DecisionTreeClassifier
18+
from sklearn.tree import DecisionTreeRegressor
1919

2020
from mabwiser._version import __author__, __email__, __version__, __copyright__
2121
from mabwiser.approximate import _LSHNearest
@@ -633,9 +633,9 @@ class TreeBandit(NamedTuple):
633633
----------
634634
tree_parameters: Dict, **kwarg
635635
Parameters of the decision tree.
636-
The keys must match the parameters of sklearn.tree.DecisionTreeClassifier.
636+
The keys must match the parameters of sklearn.tree.DecisionTreeRegressor.
637637
When a parameter is not given, the default parameters from
638-
sklearn.tree.DecisionTreeClassifier will be chosen.
638+
sklearn.tree.DecisionTreeRegressor will be chosen.
639639
Default value is an empty dictionary.
640640
641641
Example
@@ -655,10 +655,10 @@ class TreeBandit(NamedTuple):
655655

656656
def _validate(self):
657657
check_true(isinstance(self.tree_parameters, dict), TypeError("tree_parameters must be a dictionary."))
658-
tree = DecisionTreeClassifier()
658+
tree = DecisionTreeRegressor()
659659
for key in self.tree_parameters.keys():
660660
check_true(key in tree.__dict__.keys(),
661-
ValueError("sklearn.tree.DecisionTreeClassifier doesn't have a parameter " + str(key) + "."))
661+
ValueError("sklearn.tree.DecisionTreeRegressor doesn't have a parameter " + str(key) + "."))
662662

663663
def _is_compatible(self, learning_policy: LearningPolicy):
664664
# TreeBandit is compatible with these learning policies

tests/test_base.py

+3-3
Original file line number | Diff line number | Diff line change
@@ -118,11 +118,11 @@ def predict(arms: List[Arm],
118118
return expectations[0] if num_run == 1 else expectations, mab
119119

120120
@staticmethod
121-
def is_compatible(lp, np):
121+
def is_compatible(learning_policy, neighborhood_policy):
122122

123123
# Special case for TreeBandit lp/np compatibility
124-
if isinstance(np, NeighborhoodPolicy.TreeBandit):
125-
return np._is_compatible(lp)
124+
if isinstance(neighborhood_policy, NeighborhoodPolicy.TreeBandit):
125+
return neighborhood_policy._is_compatible(learning_policy)
126126

127127
return True
128128

0 commit comments

Comments
 (0)