15
15
import numpy as np
16
16
import pandas as pd
17
17
from sklearn .cluster import MiniBatchKMeans
18
- from sklearn .tree import DecisionTreeClassifier
18
+ from sklearn .tree import DecisionTreeRegressor
19
19
20
20
from mabwiser ._version import __author__ , __email__ , __version__ , __copyright__
21
21
from mabwiser .approximate import _LSHNearest
@@ -633,9 +633,9 @@ class TreeBandit(NamedTuple):
633
633
----------
634
634
tree_parameters: Dict, **kwarg
635
635
Parameters of the decision tree.
636
- The keys must match the parameters of sklearn.tree.DecisionTreeClassifier .
636
+ The keys must match the parameters of sklearn.tree.DecisionTreeRegressor .
637
637
When a parameter is not given, the default parameters from
638
- sklearn.tree.DecisionTreeClassifier will be chosen.
638
+ sklearn.tree.DecisionTreeRegressor will be chosen.
639
639
Default value is an empty dictionary.
640
640
641
641
Example
@@ -655,10 +655,10 @@ class TreeBandit(NamedTuple):
655
655
656
656
def _validate (self ):
657
657
check_true (isinstance (self .tree_parameters , dict ), TypeError ("tree_parameters must be a dictionary." ))
658
- tree = DecisionTreeClassifier ()
658
+ tree = DecisionTreeRegressor ()
659
659
for key in self .tree_parameters .keys ():
660
660
check_true (key in tree .__dict__ .keys (),
661
- ValueError ("sklearn.tree.DecisionTreeClassifier doesn't have a parameter " + str (key ) + "." ))
661
+ ValueError ("sklearn.tree.DecisionTreeRegressor doesn't have a parameter " + str (key ) + "." ))
662
662
663
663
def _is_compatible (self , learning_policy : LearningPolicy ):
664
664
# TreeBandit is compatible with these learning policies
0 commit comments