@@ -406,12 +406,12 @@ def get_configspace(name, n_classes=3, n_samples=1000, n_features=100, random_st
406
406
raise ValueError (f"Could not find configspace for { name } " )
407
407
408
408
409
- def get_search_space (name , n_classes = 3 , n_samples = 100 , n_features = 100 , random_state = None , return_choice_pipeline = True ):
409
+ def get_search_space (name , n_classes = 3 , n_samples = 100 , n_features = 100 , random_state = None , return_choice_pipeline = True , base_node = EstimatorNode ):
410
410
411
411
412
412
#if list of names, return a list of EstimatorNodes
413
413
if isinstance (name , list ) or isinstance (name , np .ndarray ):
414
- search_spaces = [get_search_space (n , n_classes = n_classes , n_samples = n_samples , n_features = n_features , random_state = random_state , return_choice_pipeline = False ) for n in name ]
414
+ search_spaces = [get_search_space (n , n_classes = n_classes , n_samples = n_samples , n_features = n_features , random_state = random_state , return_choice_pipeline = False , base_node = base_node ) for n in name ]
415
415
#remove Nones
416
416
search_spaces = [s for s in search_spaces if s is not None ]
417
417
@@ -422,12 +422,12 @@ def get_search_space(name, n_classes=3, n_samples=100, n_features=100, random_st
422
422
423
423
if name in GROUPNAMES :
424
424
name_list = GROUPNAMES [name ]
425
- return get_search_space (name_list , n_classes = n_classes , n_samples = n_samples , n_features = n_features , random_state = random_state , return_choice_pipeline = return_choice_pipeline )
425
+ return get_search_space (name_list , n_classes = n_classes , n_samples = n_samples , n_features = n_features , random_state = random_state , return_choice_pipeline = return_choice_pipeline , base_node = base_node )
426
426
427
- return get_node (name , n_classes = n_classes , n_samples = n_samples , n_features = n_features , random_state = random_state )
427
+ return get_node (name , n_classes = n_classes , n_samples = n_samples , n_features = n_features , random_state = random_state , base_node = base_node )
428
428
429
429
430
- def get_node (name , n_classes = 3 , n_samples = 100 , n_features = 100 , random_state = None ):
430
+ def get_node (name , n_classes = 3 , n_samples = 100 , n_features = 100 , random_state = None , base_node = EstimatorNode ):
431
431
432
432
#these are wrappers that take in another estimator as a parameter
433
433
# TODO Add AdaBoostRegressor, AdaBoostClassifier as wrappers? wrap a decision tree with different params?
@@ -461,39 +461,39 @@ def get_node(name, n_classes=3, n_samples=100, n_features=100, random_state=None
461
461
return EstimatorNode (STRING_TO_CLASS [name ], configspace , hyperparameter_parser = imputers .IterativeImputer_hyperparameter_parser )
462
462
if name == "RobustScaler" :
463
463
configspace = get_configspace (name , n_classes = n_classes , n_samples = n_samples , random_state = random_state )
464
- return EstimatorNode (STRING_TO_CLASS [name ], configspace , hyperparameter_parser = transformers .robust_scaler_hyperparameter_parser )
464
+ return base_node (STRING_TO_CLASS [name ], configspace , hyperparameter_parser = transformers .robust_scaler_hyperparameter_parser )
465
465
if name == "GradientBoostingClassifier" :
466
466
configspace = get_configspace (name , n_classes = n_classes , n_samples = n_samples , random_state = random_state )
467
- return EstimatorNode (STRING_TO_CLASS [name ], configspace , hyperparameter_parser = classifiers .GradientBoostingClassifier_hyperparameter_parser )
467
+ return base_node (STRING_TO_CLASS [name ], configspace , hyperparameter_parser = classifiers .GradientBoostingClassifier_hyperparameter_parser )
468
468
if name == "HistGradientBoostingClassifier" :
469
469
configspace = get_configspace (name , n_classes = n_classes , n_samples = n_samples , random_state = random_state )
470
- return EstimatorNode (STRING_TO_CLASS [name ], configspace , hyperparameter_parser = classifiers .HistGradientBoostingClassifier_hyperparameter_parser )
470
+ return base_node (STRING_TO_CLASS [name ], configspace , hyperparameter_parser = classifiers .HistGradientBoostingClassifier_hyperparameter_parser )
471
471
if name == "GradientBoostingRegressor" :
472
472
configspace = get_configspace (name , n_classes = n_classes , n_samples = n_samples , random_state = random_state )
473
- return EstimatorNode (STRING_TO_CLASS [name ], configspace , hyperparameter_parser = regressors .GradientBoostingRegressor_hyperparameter_parser )
473
+ return base_node (STRING_TO_CLASS [name ], configspace , hyperparameter_parser = regressors .GradientBoostingRegressor_hyperparameter_parser )
474
474
if name == "HistGradientBoostingRegressor" :
475
475
configspace = get_configspace (name , n_classes = n_classes , n_samples = n_samples , random_state = random_state )
476
- return EstimatorNode (STRING_TO_CLASS [name ], configspace , hyperparameter_parser = regressors .HistGradientBoostingRegressor_hyperparameter_parser )
476
+ return base_node (STRING_TO_CLASS [name ], configspace , hyperparameter_parser = regressors .HistGradientBoostingRegressor_hyperparameter_parser )
477
477
if name == "MLPClassifier" :
478
478
configspace = get_configspace (name , n_classes = n_classes , n_samples = n_samples , random_state = random_state )
479
- return EstimatorNode (STRING_TO_CLASS [name ], configspace , hyperparameter_parser = classifiers .MLPClassifier_hyperparameter_parser )
479
+ return base_node (STRING_TO_CLASS [name ], configspace , hyperparameter_parser = classifiers .MLPClassifier_hyperparameter_parser )
480
480
if name == "MLPRegressor" :
481
481
configspace = get_configspace (name , n_classes = n_classes , n_samples = n_samples , random_state = random_state )
482
- return EstimatorNode (STRING_TO_CLASS [name ], configspace , hyperparameter_parser = regressors .MLPRegressor_hyperparameter_parser )
482
+ return base_node (STRING_TO_CLASS [name ], configspace , hyperparameter_parser = regressors .MLPRegressor_hyperparameter_parser )
483
483
if name == "GaussianProcessRegressor" :
484
484
configspace = get_configspace (name , n_classes = n_classes , n_samples = n_samples , random_state = random_state )
485
- return EstimatorNode (STRING_TO_CLASS [name ], configspace , hyperparameter_parser = regressors .GaussianProcessRegressor_hyperparameter_parser )
485
+ return base_node (STRING_TO_CLASS [name ], configspace , hyperparameter_parser = regressors .GaussianProcessRegressor_hyperparameter_parser )
486
486
if name == "GaussianProcessClassifier" :
487
487
configspace = get_configspace (name , n_classes = n_classes , n_samples = n_samples , random_state = random_state )
488
- return EstimatorNode (STRING_TO_CLASS [name ], configspace , hyperparameter_parser = classifiers .GaussianProcessClassifier_hyperparameter_parser )
488
+ return base_node (STRING_TO_CLASS [name ], configspace , hyperparameter_parser = classifiers .GaussianProcessClassifier_hyperparameter_parser )
489
489
if name == "FeatureAgglomeration" :
490
490
configspace = get_configspace (name , n_features = n_features )
491
- return EstimatorNode (STRING_TO_CLASS [name ], configspace , hyperparameter_parser = transformers .FeatureAgglomeration_hyperparameter_parser )
491
+ return base_node (STRING_TO_CLASS [name ], configspace , hyperparameter_parser = transformers .FeatureAgglomeration_hyperparameter_parser )
492
492
493
493
configspace = get_configspace (name , n_classes = n_classes , n_samples = n_samples , n_features = n_features , random_state = random_state )
494
494
if configspace is None :
495
495
#raise warning
496
496
warnings .warn (f"Could not find configspace for { name } " )
497
497
return None
498
498
499
- return EstimatorNode (STRING_TO_CLASS [name ], configspace )
499
+ return base_node (STRING_TO_CLASS [name ], configspace )
0 commit comments