Skip to content

Commit cad1d68

Browse files
committed
Update get_sampler
1 parent 82aa824 commit cad1d68

File tree

1 file changed

+13
-16
lines changed

1 file changed

+13
-16
lines changed

hypergbm/hyper_gbm.py

+13-16
Original file line numberDiff line numberDiff line change
@@ -38,32 +38,29 @@
3838
logger = logging.get_logger(__name__)
3939

4040
GB = 1024 ** 3
41+
SAMPLERS = {}
4142

4243
try:
4344
from imblearn.over_sampling import RandomOverSampler, SMOTE, ADASYN
4445
from imblearn.under_sampling import RandomUnderSampler, NearMiss, TomekLinks, EditedNearestNeighbours
4546

46-
imblearn_installed = True
47+
im_samplers = {'RandomOverSampler': RandomOverSampler,
48+
'SMOTE': SMOTE,
49+
'ADASYN': ADASYN,
50+
'RandomUnderSampler': RandomUnderSampler,
51+
'NearMiss': NearMiss,
52+
'TomekLinks': TomekLinks,
53+
'EditedNearestNeighbours': EditedNearestNeighbours
54+
}
55+
SAMPLERS.update(im_samplers)
4756
except:
4857
logger.warning('Failed to load imbalanced-learn', exc_info=sys.exc_info())
49-
imblearn_installed = False
5058

5159

5260
def get_sampler(sampler):
53-
if imblearn_installed:
54-
samplers = {'RandomOverSampler': RandomOverSampler,
55-
'SMOTE': SMOTE,
56-
'ADASYN': ADASYN,
57-
'RandomUnderSampler': RandomUnderSampler,
58-
'NearMiss': NearMiss,
59-
'TomekLinks': TomekLinks,
60-
'EditedNearestNeighbours': EditedNearestNeighbours
61-
}
62-
sampler_cls = samplers.get(sampler)
63-
if sampler_cls is not None:
64-
return sampler_cls()
65-
else:
66-
return None
61+
sampler_cls = SAMPLERS.get(sampler)
62+
if sampler_cls is not None:
63+
return sampler_cls()
6764
else:
6865
return None
6966

0 commit comments

Comments
 (0)